This repository contains an implementation of BASNet in TensorFlow/Keras.
Note: We are looking for collaborators with good compute power to help us train the full model.
If you are interested, please contact us at hamidriasat@gmail.com.
- images: Model architecture images
- sample_data: Sample images for visualization
- weights: Model checkpoint directory
- basnet.py: Contains model implementation
- basnet_prediction.ipynb: Notebook to visualize model output
- basnet_training.ipynb: Notebook to train model
- dataloader.py: Dataloader to efficiently load data into memory
- loss.py: Implementation of BASNet hybrid loss function
- utils.py: Generic utility functions
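The hybrid loss in loss.py follows the BASNet paper. It sums three terms with equal weight: binary cross-entropy at the pixel level, SSIM at the patch level and IoU at the map level. The sketch below is not the code in loss.py. It assumes the prediction is a single sigmoid output of shape (batch, H, W, 1) with H and W of at least 11, which is the default SSIM window. It also uses TensorFlow's built-in `tf.image.ssim` in place of a hand-written SSIM. The name `hybrid_loss` is hypothetical.

```python
import tensorflow as tf


def hybrid_loss(y_true, y_pred, eps=1e-7):
    """Sketch of a BCE + SSIM + IoU hybrid loss (hypothetical helper).

    y_true, y_pred: float tensors of shape (batch, H, W, 1) with values in [0, 1].
    Assumes H, W >= 11 so tf.image.ssim's default 11x11 window fits.
    """
    y_true = tf.cast(y_true, tf.float32)
    # Keep predictions away from exactly 0/1 so the log in BCE stays finite.
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), eps, 1.0 - eps)

    # Pixel level: binary cross-entropy averaged over all pixels.
    bce = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))

    # Patch level: 1 - SSIM, averaged over the batch.
    ssim = 1.0 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))

    # Map level: soft IoU computed per image, then averaged over the batch.
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true + y_pred, axis=[1, 2, 3]) - intersection
    iou = 1.0 - tf.reduce_mean(intersection / (union + eps))

    return bce + ssim + iou
```

BASNet is deeply supervised. The paper applies this per-output loss to every side output of the network and sums the results, rather than supervising only the final map.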
Like the paper, we use the DUTS-TR dataset for training. It contains 10,553 images. Commands to download the data are in the training notebook.
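A minimal tf.data pipeline for DUTS-TR image/mask pairs could look like the sketch below. It is not the repository's dataloader.py. The 256x256 input size is an assumption, the names `load_pair` and `make_dataset` are hypothetical, and only the batch size of 8 comes from this README.

```python
import tensorflow as tf

IMAGE_SIZE = 256  # assumption: input resolution; the repository may use another size
BATCH_SIZE = 8    # batch size reported in this README


def load_pair(image_path, mask_path):
    """Read one RGB image and its grayscale mask, resized and scaled to [0, 1]."""
    image = tf.io.decode_image(
        tf.io.read_file(image_path), channels=3, expand_animations=False)
    mask = tf.io.decode_image(
        tf.io.read_file(mask_path), channels=1, expand_animations=False)
    # tf.image.resize returns float32, so dividing by 255 maps pixels to [0, 1].
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
    mask = tf.image.resize(mask, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
    return image, mask


def make_dataset(image_paths, mask_paths, shuffle=True):
    """Build a shuffled, batched, prefetched dataset from matching path lists."""
    ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(image_paths))
    # Decode files in parallel and overlap loading with training.
    ds = ds.map(load_pair, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
```

Pass two path lists sorted so that each image lines up with its mask. Setting `drop_remainder=True` keeps every batch at exactly 8 samples.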
We trained the model on the Google Colab Pro+ plan using an A100 (40 GB) GPU.
We trained the model for 100 epochs (~120k iterations) with a batch size of 8.
Training took almost 24 hours.
In the paper, the authors trained the model for 400k iterations. That is why our results are
not as good as the authors' results, but they are enough to demonstrate the model's learning ability.
Training code can be found in basnet_training.ipynb.
Model output can be visualized using basnet_prediction.ipynb.
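Before you visualize or save a prediction, a common step is to min-max normalize the predicted saliency map to [0, 1]. You can then threshold it into a binary mask. The helper below is a hypothetical sketch, not code from basnet_prediction.ipynb. It assumes you have already picked one (H, W, 1) map out of the model's outputs, and the 0.5 threshold is an arbitrary choice.

```python
import tensorflow as tf


def postprocess_saliency(pred, threshold=0.5, eps=1e-8):
    """Min-max normalize a saliency map and derive a binary mask (hypothetical helper).

    pred: float tensor of shape (H, W, 1) from the model's output.
    Returns (normalized map in [0, 1], uint8 mask with values 0 or 255).
    """
    pred = tf.cast(pred, tf.float32)
    lo, hi = tf.reduce_min(pred), tf.reduce_max(pred)
    # Stretch the map so its smallest value becomes 0 and its largest becomes 1.
    norm = (pred - lo) / (hi - lo + eps)
    # Pixels at or above the threshold become foreground (255), the rest background (0).
    binary = tf.cast(norm >= threshold, tf.uint8) * 255
    return norm, binary
```

The normalized map is convenient for display with a grayscale colormap. The binary mask can be written to disk with `tf.io.encode_png`.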
Pretrained weights are available at this Google Drive link. Commands to download the weights are in the prediction notebook.
Dependencies:
- KerasCV 0.5.0
- TensorFlow 2.12.0 (any KerasCV-compatible version should work)
Licensed under the MIT License