The official PyTorch implementation of "Learning Truncated Causal History Model for Video Restoration", accepted to NeurIPS 2024.
- Turtle achieves state-of-the-art results on multiple video restoration benchmarks, offering superior computational efficiency and enhanced restoration quality 🔥🔥🔥.
- 🛠️💡Model Forge: Easily design your own architecture by modifying the option file.
- You can choose from various types of layers (channel attention, simple channel attention, CHM, FHR, or custom blocks) as well as different types of feed-forward layers.
- This setup allows you to create custom networks and experiment with layer and feed-forward configurations to suit your needs.
- If you like this project, please give us a ⭐ on GitHub! 🚀
- Oct. 10, 2024: The paper is now available on arXiv, along with the code and pretrained models.
- Sept. 25, 2024: Turtle was accepted to NeurIPS 2024.
- Installation
- Trained Models
- Dataset Preparation
- Training
- Evaluation
- Model Complexity and Inference Speed
- Acknowledgments
- Citation
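This implementation is based on BasicSR, an open-source toolbox for image and video restoration tasks.
Environment:
- Python 3.9.5
- PyTorch 1.11.0
- CUDA 11.3
Install the dependencies and set up the package (without building the CUDA extensions):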
pip install -r requirements.txt
python setup.py develop --no_cuda_ext
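To check quickly whether your environment matches the versions listed above, a small script such as the following can help. It is a sketch, not part of the repository: it only compares major.minor versions (an assumption about what matters for compatibility) and treats a missing PyTorch install as a mismatch.

```python
import sys

# Versions listed in the environment section above.
TESTED = {"python": "3.9.5", "torch": "1.11.0", "cuda": "11.3"}


def major_minor(version):
    """Return (major, minor) from strings like '1.11.0+cu113' or '11.3'."""
    core = version.split("+")[0]  # drop local build tags such as '+cu113'
    parts = core.split(".")
    return int(parts[0]), int(parts[1])


def same_major_minor(found, expected):
    """True if `found` is present and shares major.minor with `expected`."""
    return found is not None and major_minor(found) == major_minor(expected)


def report():
    found = {
        "python": "{}.{}.{}".format(*sys.version_info[:3]),
        "torch": None,
        "cuda": None,
    }
    try:
        import torch

        found["torch"] = torch.__version__
        found["cuda"] = torch.version.cuda  # None for CPU-only builds
    except ImportError:
        pass
    for name, expected in TESTED.items():
        status = "ok" if same_major_minor(found[name], expected) else "differs"
        print(f"{name}: found {found[name]}, listed {expected} -> {status}")


if __name__ == "__main__":
    report()
```

A mismatch does not necessarily mean the code will fail; it only flags where your setup deviates from the listed one.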
You can download our trained models from Google Drive: Trained Models
To obtain the datasets, follow the official instructions of each dataset's provider and download them into the datasets folder. The datasets for each task are available from the following links (official sources reported by their respective authors):
- Desnowing: RSVD
- Raindrops and Rainstreaks Removal: VRDS
- Night Deraining: NightRain
- Synthetic Deblurring: GoPro
- Real-World Deblurring: BSD3ms-24ms
- Denoising: DAVIS | Set8
- Real-World Super Resolution: MVSR
Organize the directory structure as follows, with the ground-truth reference frames in 'gt' and the degraded frames in 'blur':
./datasets/
└── Dataset_name/
    ├── train/
    └── test/
        ├── blur/
        │   ├── video_1/
        │   │   ├── Frame1
        │   │   └── ...
        │   └── video_n/
        │       ├── Frame1
        │       └── ...
        └── gt/
            ├── video_1/
            │   ├── Frame1
            │   └── ...
            └── video_n/
                ├── Frame1
                └── ...
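Before training or evaluation, it can help to verify that every degraded video folder has a matching ground-truth folder. The sketch below assumes that paired frames share the same filename in blur/ and gt/ (the layout above does not state this explicitly), and check_split is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def check_split(split_dir):
    """Return a list of problems found when pairing <split>/blur with <split>/gt."""
    split_dir = Path(split_dir)
    blur_root, gt_root = split_dir / "blur", split_dir / "gt"

    problems = [f"missing folder: {root}" for root in (blur_root, gt_root) if not root.is_dir()]
    if problems:
        return problems

    blur_videos = {p.name for p in blur_root.iterdir() if p.is_dir()}
    gt_videos = {p.name for p in gt_root.iterdir() if p.is_dir()}

    # Video folders present on only one side cannot be paired at all.
    for name in sorted(blur_videos ^ gt_videos):
        problems.append(f"video folder only on one side: {name}")

    # For shared videos, compare frame filenames (assumed identical across sides).
    for name in sorted(blur_videos & gt_videos):
        blur_frames = {p.name for p in (blur_root / name).iterdir() if p.is_file()}
        gt_frames = {p.name for p in (gt_root / name).iterdir() if p.is_file()}
        if blur_frames != gt_frames:
            problems.append(f"{name}: {len(blur_frames ^ gt_frames)} unmatched frame file(s)")
    return problems
```

For example, `for problem in check_split("./datasets/Dataset_name/test"): print(problem)` prints nothing when the split is consistent.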
Before training, select the appropriate data loader in basicsr/train.py. There are two options:
- For deblurring, denoising, deraining, etc., keep the following import line and comment out the super-resolution one:
  from basicsr.data.video_image_dataset import VideoImageDataset
- For super-resolution, keep the following import line and comment out the previous one:
  from basicsr.data.video_super_image_dataset import VideoSuperImageDataset as VideoImageDataset
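Instead of commenting import lines in and out, the same choice between the two loaders above could be driven by a flag. This is a hypothetical sketch: the task keys and the load_dataset_class helper are illustrative and do not exist in the repository; only the module and class names come from the import lines above.

```python
import importlib

# The two loader options described above, keyed by a hypothetical task flag.
DATASET_LOADERS = {
    "restoration": ("basicsr.data.video_image_dataset", "VideoImageDataset"),
    "sr": ("basicsr.data.video_super_image_dataset", "VideoSuperImageDataset"),
}


def dataset_import_path(task):
    """Return (module_name, class_name) for the given task flag."""
    try:
        return DATASET_LOADERS[task]
    except KeyError:
        raise ValueError(
            f"unknown task {task!r}; expected one of {sorted(DATASET_LOADERS)}"
        ) from None


def load_dataset_class(task):
    """Import and return the dataset class (requires the repo to be installed)."""
    module_name, class_name = dataset_import_path(task)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# In train.py this could replace the two import lines, e.g.:
# VideoImageDataset = load_dataset_class("sr")
```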
Then launch training (the command below uses 8 GPUs):
python -m torch.distributed.launch --nproc_per_node=8 --master_port=8080 basicsr/train.py -opt /options/option_file_name.yml --launcher pytorch
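If you train on a different number of GPUs, --nproc_per_node should match the number of GPUs you use. The following sketch (build_train_command is a hypothetical helper, not part of the repository) reassembles the training command with a configurable GPU count and port; the option-file path is the same placeholder used above and must be replaced with your own.

```python
import shlex


def build_train_command(opt_path, num_gpus=8, master_port=8080):
    """Assemble the distributed training command with a configurable GPU count."""
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={num_gpus}",
        f"--master_port={master_port}",
        "basicsr/train.py",
        "-opt", opt_path,
        "--launcher", "pytorch",
    ]


if __name__ == "__main__":
    cmd = build_train_command("/options/option_file_name.yml", num_gpus=2)
    print(shlex.join(cmd))  # inspect the command before running it
    # To launch: import subprocess; subprocess.run(cmd, check=True)
```

If the option file sets a per-GPU batch size, the effective batch size scales with the number of GPUs, so changing the GPU count may also call for revisiting the learning-rate settings.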
The pretrained models can be downloaded from the Google Drive link above.
To evaluate a pretrained model, run:
python inference.py
Adjust the function parameters in inference.py according to each task's requirements:
- config: path to the option file.
- model_path: location of the pretrained model.
- dataset_name: the dataset you are using ("RSVD", "GoPro", "SR", "NightRain", "DVD", "Set8").
- task_name: the restoration task ("Desnowing", "Deblurring", "SR", "Deraining", "Denoising").
- model_type: the model type ("t0", "t1", "SR").
- save_image: set to True to save the output images, and provide the output path in image_out_path.
- do_patches: enable to process images in patches; adjust tile and tile_overlap as needed (defaults: 320 and 128).
- y_channel_PSNR: enable to compute PSNR/SSIM on the Y channel (default: False).
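The do_patches and y_channel_PSNR options above can be illustrated with a NumPy sketch. It assumes SwinIR-style tiling (tile starts spaced tile - tile_overlap pixels apart, so 192 pixels with the defaults, with the last tile flush against the border and overlapping outputs averaged) and a BasicSR-style BT.601 luma conversion. inference.py may differ in details such as blending or border cropping; tile_starts, tiled_apply, y_channel, and psnr_y are hypothetical names, and the model is assumed to return a patch of the same size as its input, so the SR case is not covered.

```python
import math

import numpy as np


def tile_starts(length, tile=320, overlap=128):
    """Start offsets of overlapping tiles that cover [0, length)."""
    if length <= tile:
        return [0]
    stride = tile - overlap
    starts = list(range(0, length - tile, stride))
    return starts + [length - tile]  # last tile is flush with the border


def tiled_apply(img, model, tile=320, overlap=128):
    """Run `model` on overlapping tiles of an HxWxC array and average the overlaps."""
    h, w = img.shape[:2]
    out = np.zeros_like(img, dtype=np.float64)
    weight = np.zeros((h, w, 1), dtype=np.float64)
    for y in tile_starts(h, tile, overlap):
        for x in tile_starts(w, tile, overlap):
            patch = img[y:y + tile, x:x + tile]
            out[y:y + tile, x:x + tile] += model(patch)
            weight[y:y + tile, x:x + tile] += 1.0  # how many tiles cover each pixel
    return out / weight


def y_channel(rgb):
    """BT.601 luma (range 16 to 235) from an RGB float image in [0, 1]."""
    return rgb @ np.array([65.481, 128.553, 24.966]) + 16.0


def psnr_y(img1, img2):
    """PSNR between two RGB images in [0, 1], computed on the Y channel."""
    mse = np.mean((y_channel(img1) - y_channel(img2)) ** 2)
    return math.inf if mse == 0 else 10.0 * math.log10(255.0 ** 2 / mse)
```

For a 720-pixel dimension with the default settings, tile_starts returns [0, 192, 384, 400], so every pixel is covered by at least one 320-pixel tile. Averaging overlaps reduces visible seams at tile borders compared with simply stitching non-overlapping tiles.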
This pipeline processes a video without ground truth by extracting its frames and running a pretrained model on them, for tasks such as desnowing:
- Edit video_to_frames.py: set video_path to your input video file and output_folder to the folder where the extracted frames will be saved.
- Run the script:
  python video_to_frames.py
- Edit inference_no_ground_truth.py: set the paths for config, model_path, data_dir (the extracted frames), and image_out_path (the output frames).
- Run the script:
  python inference_no_ground_truth.py
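The frame-extraction step in this pipeline can be sketched with OpenCV. This is an illustration under assumptions, not the repository's video_to_frames.py: it assumes opencv-python is installed, and the zero-padded PNG naming scheme in the hypothetical frame_name helper is chosen so that a plain sorted() listing returns frames in temporal order; check which naming inference_no_ground_truth.py expects before relying on it.

```python
from pathlib import Path


def frame_name(index, digits=5, ext=".png"):
    """Zero-padded name so that sorted() returns frames in temporal order."""
    return f"{index:0{digits}d}{ext}"


def extract_frames(video_path, output_folder):
    """Write every frame of `video_path` to `output_folder`; return the count."""
    import cv2  # assumed installed: pip install opencv-python

    out_dir = Path(output_folder)
    out_dir.mkdir(parents=True, exist_ok=True)
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise IOError(f"cannot open video: {video_path}")
    count = 0
    try:
        while True:
            ok, frame = cap.read()  # frame is a BGR uint8 array
            if not ok:  # end of stream or read error
                break
            cv2.imwrite(str(out_dir / frame_name(count)), frame)
            count += 1
    finally:
        cap.release()
    return count
```

Without zero padding, a lexicographic listing would place "10.png" before "2.png" and scramble the temporal order the model relies on.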
- To get the parameter count, MACs, and inference speed, run:
python basicsr/models/archs/turtle_arch.py
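If you want to measure another model, or Turtle with different settings, the parameter count and latency can be measured in a framework-agnostic way as sketched below. This is not what turtle_arch.py does internally (that script is not reproduced here), and MAC counting is left out. count_parameters and benchmark are hypothetical helpers; with PyTorch, pass sync=torch.cuda.synchronize so that queued GPU kernels are included in the timing.

```python
import statistics
import time


def count_parameters(model):
    """Total number of parameters of a torch.nn.Module-like object."""
    return sum(p.numel() for p in model.parameters())


def benchmark(fn, *args, warmup=5, iters=20, sync=None):
    """Median wall-clock seconds per call of fn(*args)."""
    for _ in range(warmup):  # let caches, cuDNN autotuning, etc. settle
        fn(*args)
    if sync is not None:
        sync()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # wait for queued GPU work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Example with PyTorch (assumes `model` and input tensor `x` are on the GPU):
# with torch.no_grad():
#     seconds = benchmark(model, x, sync=torch.cuda.synchronize)
```

The median is used instead of the mean so that occasional slow iterations (for example, from background processes) distort the result less.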
We invite the community to help extend Turtle to other low-level vision tasks. Below are specific areas where contributions would be especially valuable if the resulting models are open-sourced. If you have other suggestions or requests, please feel free to open an issue.
- Training TURTLE for Synthetic Super-Resolution Tasks
- Bicubic (BI) Degradation: Train on REDS and Vimeo90K; evaluate on REDS4 and Vimeo90K-T.
- Blur-Downsampling (BD) Degradation: Train on Vimeo90K and evaluate on Vimeo90K-T, Vid4, UDM10.
For more information on dataset selection and data preparation, please refer to Section 4.3 in this paper.
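For the BD setting listed above, low-resolution frames are commonly produced by a Gaussian blur with σ = 1.6 followed by keeping every 4th pixel. The NumPy sketch below illustrates that degradation under assumptions: the 13-tap kernel, the reflect padding, and sampling from pixel offset 0 are choices made here and may not match the exact protocol of any published benchmark, so follow the data preparation referenced above when producing results for comparison. gaussian_kernel_1d and bd_downsample are hypothetical names.

```python
import numpy as np


def gaussian_kernel_1d(size=13, sigma=1.6):
    """Normalized 1-D Gaussian kernel."""
    x = np.arange(size) - (size - 1) / 2.0
    k = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    return k / k.sum()


def bd_downsample(frame, scale=4, size=13, sigma=1.6):
    """Gaussian blur, then keep every `scale`-th pixel (BD-style degradation).

    frame: HxW or HxWxC float array. Returns an array downscaled by `scale`.
    """
    k = gaussian_kernel_1d(size, sigma)
    pad = size // 2
    pad_width = [(pad, pad), (pad, pad)] + [(0, 0)] * (frame.ndim - 2)
    padded = np.pad(frame.astype(np.float64), pad_width, mode="reflect")
    # Separable blur: filter rows, then columns; 'valid' mode removes the padding again.
    blurred = np.apply_along_axis(lambda r: np.convolve(r, k, mode="valid"), 1, padded)
    blurred = np.apply_along_axis(lambda c: np.convolve(c, k, mode="valid"), 0, blurred)
    return blurred[::scale, ::scale]
```

Unlike BI degradation, where a bicubic resize with built-in antialiasing produces the low-resolution frame, BD applies an explicit blur kernel and then subsamples, so the two settings yield different low-resolution inputs for the same clip.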
This codebase borrows from the BasicSR and ShiftNet repositories.
If you find our work useful, please consider citing our paper in your research.
@inproceedings{ghasemabadilearning,
title={Learning Truncated Causal History Model for Video Restoration},
author={Ghasemabadi, Amirhosein and Janjua, Muhammad Kamran and Salameh, Mohammad and Niu, Di},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}
}