Arpit Aggarwal
In this project, different CNN architectures (VGG-16, VGG-19 and ResNet-50), with and without the CBAM module that adds Spatial and Channel Attention to the feature maps, were used for the task of Dog-Cat image classification. The input to each network was a (224 x 224 x 3) image and the number of classes was 2, where '0' stands for a cat and '1' for a dog. The CNN architectures were implemented in PyTorch and the loss function was Cross Entropy Loss. The hyperparameters to be tuned were: number of epochs (e), learning rate (lr), momentum (m), weight decay (wd) and batch size (bs).
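For readers unfamiliar with CBAM, here is a minimal sketch of how such a module is usually built: channel attention first, then spatial attention, each multiplied into the feature map. It follows the commonly published CBAM layout and is not necessarily the code used in the notebooks. The class names, the reduction ratio of 16 and the 7 x 7 kernel are assumptions.

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Weights each channel using average- and max-pooled descriptors."""

    def __init__(self, channels, reduction=16):  # reduction=16 is an assumed default
        super().__init__()
        # Shared two-layer MLP, written as 1x1 convolutions.
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(torch.mean(x, dim=(2, 3), keepdim=True))
        mx = self.mlp(torch.amax(x, dim=(2, 3), keepdim=True))
        return torch.sigmoid(avg + mx)  # shape (N, C, 1, 1)


class SpatialAttention(nn.Module):
    """Weights each spatial location using channel-wise average and max maps."""

    def __init__(self, kernel_size=7):  # kernel_size=7 is an assumed default
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg = torch.mean(x, dim=1, keepdim=True)
        mx = torch.amax(x, dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))  # shape (N, 1, H, W)


class CBAM(nn.Module):
    """Applies channel attention, then spatial attention, to a feature map."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel = ChannelAttention(channels, reduction)
        self.spatial = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel(x)
        return x * self.spatial(x)


# The output has the same shape as the input feature map.
features = torch.randn(2, 64, 56, 56)
refined = CBAM(64)(features)
```

Because the output shape matches the input shape, a block like this can be placed after a convolutional stage of VGG or ResNet without changing the layers that follow.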
The data for the task of Dog-Cat image classification can be downloaded from: https://drive.google.com/drive/folders/1EdVqRCT1NSYT6Ge-SvAIu7R5i9Og2tiO?usp=sharing. The dataset has been divided into three sets: Training data, Validation data and Testing data. The different CNN architectures (with and without the CBAM module) were analysed by comparing their Training Accuracy and Validation Accuracy values.
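As a rough illustration of how an accuracy value of this kind can be computed, the sketch below counts correct predictions over a data loader and reports a percentage. The function name `accuracy_percent` and the tiny stand-in model and random tensors are hypothetical, used only so the example runs on its own. In practice the loader would wrap the Training or Validation split.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def accuracy_percent(model, loader, device="cpu"):
    """Return the percentage of samples in `loader` that `model` classifies correctly."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)  # 0 = cat, 1 = dog
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return 100.0 * correct / total


# Stand-in data and model (assumptions, not the project's dataset or networks).
demo_images = torch.randn(8, 3, 32, 32)
demo_labels = torch.randint(0, 2, (8,))
demo_loader = DataLoader(TensorDataset(demo_images, demo_labels), batch_size=4)
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 2))

demo_accuracy = accuracy_percent(demo_model, demo_loader)
```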
The results after using different CNN architectures are given below:
- VGG-16 (without CBAM, pre-trained on ImageNet dataset)
Training Accuracy = 99.27% and Validation Accuracy = 96.73% (e = 50, lr = 0.005, m = 0.9, bs = 32, wd = 0.001)
- VGG-19 (without CBAM, pre-trained on ImageNet dataset)
Training Accuracy = 99.13% and Validation Accuracy = 97.25% (e = 50, lr = 0.005, m = 0.9, bs = 32, wd = 5e-4)
- ResNet-50 (without CBAM, pre-trained on ImageNet dataset)
Training Accuracy = 99.43% and Validation Accuracy = 98.43% (e = 50, lr = 0.005, m = 0.9, bs = 32, wd = 5e-4)
- VGG-16 (with CBAM, pre-trained on ImageNet dataset)
Training Accuracy = 99.66% and Validation Accuracy = 98.21% (e = 30, lr = 1e-3, m = 0.9, bs = 32, wd = 5e-4)
- VGG-19 (with CBAM, pre-trained on ImageNet dataset)
Training Accuracy = 99.66% and Validation Accuracy = 98.95% (e = 30, lr = 1e-3, m = 0.9, bs = 32, wd = 0.001)
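To show how hyperparameters like the ones listed above map onto PyTorch objects, here is a small sketch using the values reported for the ResNet-50 (without CBAM) run. The optimizer is not named in this README. SGD is assumed here because momentum and weight decay are among the tuned values. The small `model` is a stand-in so the example runs without downloading pre-trained weights.

```python
import torch
import torch.nn as nn

# Values reported above for ResNet-50 (without CBAM).
config = {"epochs": 50, "lr": 0.005, "momentum": 0.9, "batch_size": 32, "weight_decay": 5e-4}

# Stand-in network with a 2-class output (assumption: the real runs use VGG/ResNet backbones).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

criterion = nn.CrossEntropyLoss()
# Assumption: SGD with momentum; the README does not state which optimizer was used.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=config["lr"],
    momentum=config["momentum"],
    weight_decay=config["weight_decay"],
)

# One training step on a random batch of (224 x 224 x 3) images.
images = torch.randn(config["batch_size"], 3, 224, 224)
labels = torch.randint(0, 2, (config["batch_size"],))  # 0 = cat, 1 = dog

optimizer.zero_grad()
logits = model(images)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
```

A full run would repeat this step over every batch of the training set for `config["epochs"]` epochs.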
To run the Jupyter notebooks, use Python 3. Standard libraries like NumPy and PyTorch are used.
The following links were helpful for this project: