Harsh Rangwani*, Sumukh K Aithal*, Mayank Mishra, R. Venkatesh Babu
This is the official PyTorch implementation for our NeurIPS'22 paper: Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data [OpenReview] [arXiv]
UPDATE: We integrated our method with GLMC (CVPR 2023). Our method yields a ~2% gain over GLMC (SotA) 😄 [link].
Real-world datasets exhibit imbalances of varying types and degrees. Several techniques based on re-weighting and margin adjustment of the loss are often used to enhance the performance of neural networks, particularly on minority classes. In this work, we analyze the class-imbalanced learning problem by examining the loss landscape of neural networks trained with re-weighting and margin-based techniques. Specifically, we examine the spectral density of the Hessian of the class-wise loss, through which we observe that the network weights converge to a saddle point in the loss landscapes of minority classes. Following this observation, we also find that optimization methods designed to escape from saddle points can be effectively used to improve generalization on minority classes. We further demonstrate, theoretically and empirically, that Sharpness-Aware Minimization (SAM), a recent technique that encourages convergence to flat minima, can be effectively used to escape saddle points for minority classes. Using SAM results in a 6.2% increase in accuracy on the minority classes over the state-of-the-art Vector Scaling Loss, leading to an overall average increase of 4% across imbalanced datasets.
TLDR: On imbalanced datasets, training converges to saddle points in the loss landscape of tail classes, and SAM can effectively escape from these solutions.
- pytorch 1.9.1
- torchvision 0.10.1
- wandb 0.12.2
- timm 0.5.5
- prettytable 2.2.0
- scikit-learn
- matplotlib
- tensorboardX
git clone https://github.com/val-iisc/Saddle-LongTail.git
cd Saddle-LongTail
pip install -r requirements.txt
We use Weights and Biases (wandb) to track our experiments and results. To track your experiments with wandb, create a new project with your account. The `project` and `entity` arguments in `wandb.init` must be changed accordingly in the `.py` file for each experiment. To disable wandb tracking, remove the `--log_results` flag.
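The tracking setup described above can be sketched as follows. This is a minimal illustration, not the repository's code: the helper names `build_parser` and `init_tracking` and the placeholder values `my-project` and `my-entity` are hypothetical; only the `--log_results` flag and the `project`/`entity` arguments of `wandb.init` come from the text.

```python
import argparse


def build_parser():
    # Hypothetical minimal parser; only --log_results mirrors a flag from this README.
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_results", action="store_true",
                        help="enable wandb tracking")
    parser.add_argument("--seed", type=int, default=0)
    return parser


def init_tracking(args, project="my-project", entity="my-entity"):
    """Start a wandb run only when --log_results is passed.

    Placeholders: replace `project` and `entity` with your own wandb
    project name and username/team.
    """
    if not args.log_results:
        return None  # tracking disabled; wandb is never imported
    import wandb  # imported lazily so untracked runs do not need wandb
    return wandb.init(project=project, entity=entity, config=vars(args))
```

Importing `wandb` inside the function keeps runs without `--log_results` free of the dependency entirely.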
The datasets used in this repository can be downloaded by following the instructions at these links:
The CIFAR datasets are automatically downloaded to the `data/` folder if they are not already available.
Sample command to train on the CIFAR-10 LT dataset with LDAM+DRW+SAM.
python cifar_train_sam.py --gpu 0 --imb_type exp --imb_factor 0.01 --loss_type LDAM --train_rule DRW --rho 0.8 --rho_schedule none --log_results --dataset cifar10 --seed 0
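`--train_rule DRW` enables deferred re-weighting: the loss is unweighted early in training and class-weighted later. Below is a sketch of per-class weights following the class-balanced "effective number" scheme used in the LDAM-DRW codebase this repository builds on. The switch epoch (160) and `beta = 0.9999` are assumptions taken from that codebase's CIFAR defaults, not verified against this repository's scripts, and the function name is hypothetical.

```python
def drw_class_weights(cls_num_list, epoch, drw_epoch=160, beta=0.9999):
    """Per-class loss weights under deferred re-weighting (DRW).

    Assumed defaults (drw_epoch=160, beta=0.9999) follow the LDAM-DRW codebase.
    Before `drw_epoch` every class gets weight 1. Afterwards each class is
    weighted by the inverse of its effective number of samples,
    (1 - beta**n) / (1 - beta), and the weights are rescaled so that they
    sum to the number of classes.
    """
    if epoch < drw_epoch:
        return [1.0] * len(cls_num_list)
    weights = [(1.0 - beta) / (1.0 - beta ** n) for n in cls_num_list]
    scale = len(cls_num_list) / sum(weights)
    return [w * scale for w in weights]
```

In a PyTorch training loop, these weights would typically be converted to a tensor and passed as the `weight` argument of the classification loss.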
Sample command to train on the ImageNet-LT dataset with LDAM+DRW+SAM.
python imnet_train_sam.py --gpu 0 --imb_type exp --imb_factor 0.01 --data_path <Path-to-Dataset> --loss_type LDAM --train_rule DRW --dataset imagenet -b 256 --epochs 90 --arch resnet50 --cos_lr --rho_schedule step --lr 0.2 --seed 0 --rho_steps 0.05 0.1 0.5 0.5 --log_results --wd 2e-4 --margin 0.3
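`--loss_type LDAM` selects the label-distribution-aware margin loss, which enlarges the margin for rare classes. The sketch below assumes that `--margin` plays the role of the maximum margin `max_m` and uses the scale `s = 30` from the LDAM-DRW defaults; both are assumptions, and the function names are hypothetical. It computes the loss for a single example on plain Python lists rather than PyTorch tensors.

```python
import math


def ldam_margins(cls_num_list, max_m=0.5):
    """Class-dependent margins m_j proportional to n_j ** (-1/4), scaled so
    that the rarest class receives margin `max_m`."""
    raw = [n ** -0.25 for n in cls_num_list]
    scale = max_m / max(raw)
    return [m * scale for m in raw]


def ldam_loss(logits, target, margins, s=30.0):
    """LDAM loss for one example (assumed scale s=30, as in LDAM-DRW).

    Subtract the target class's margin from its logit, scale all logits by s,
    then apply softmax cross-entropy."""
    adjusted = [s * (z - margins[target]) if j == target else s * z
                for j, z in enumerate(logits)]
    top = max(adjusted)  # subtract the max for numerical stability
    log_sum_exp = top + math.log(sum(math.exp(a - top) for a in adjusted))
    return log_sum_exp - adjusted[target]
```

Because the margin is subtracted only from the true-class logit, an example of a rare class must be classified with a larger gap before its loss becomes small.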
All the commands to reproduce the experiments are available in `run.sh`.
We show results on the CIFAR-10 LT, CIFAR-100 LT, ImageNet-LT and iNaturalist-18 datasets. Complete results are available in the paper.
| Dataset | Method | Accuracy (%) | Checkpoints |
|---|---|---|---|
| CIFAR-10 LT (IF=100) | LDAM+DRW+SAM | 81.9 | ckpt |
| | CE+DRW+SAM | 80.6 | ckpt |
| CIFAR-100 LT (IF=100) | LDAM+DRW+SAM | 45.4 | ckpt |
| | CE+DRW+SAM | 44.6 | ckpt |
| ImageNet-LT | LDAM+DRW+SAM | 53.1 | ckpt |
| | CE+DRW+SAM | 47.1 | ckpt |
| iNaturalist-18 | LDAM+DRW+SAM | 70.1 | ckpt |
| | CE+DRW+SAM | 65.3 | ckpt |
We also run our method with GLMC (CVPR 2023), the latest state-of-the-art method, and demonstrate that the proposed method can further improve its performance. As conjectured in our work, we apply SAM to the re-weighting loss of GLMC to avoid saddle points. Note that we use a `GLMC-2023/run.sh`.
To run GLMC+SAM, specify the additional parameter `--rho` (e.g. `--rho 0.05`), as in the example command below:
python GLMC-2023/main.py --dataset cifar10 -a resnet34 --num_classes 10 --imbanlance_rate 0.02 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.0 --label_weighting 1.2 --contrast_weight 1 --rho 0.05
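The SAM update that `--rho` controls can be illustrated without any framework. The sketch below follows the generic two-step SAM rule on plain Python lists; it is not the code used in this repository, which relies on the PyTorch SAM implementation from https://github.com/davda54/sam. For GLMC+SAM, `grad_fn` would correspond to the gradient of GLMC's re-weighting loss. The function names and the toy quadratic are hypothetical.

```python
import math


def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One Sharpness-Aware Minimization (SAM) update on a parameter list `w`.

    Hypothetical helper; a sketch of the generic SAM rule only.
    1. Ascent: perturb the weights to w + e with e = rho * g / ||g||, which
       approximately maximizes the loss inside an L2 ball of radius rho.
    2. Descent: update the ORIGINAL weights using the gradient evaluated at
       the perturbed point.
    """
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


def quadratic_loss(w):
    # Toy ill-conditioned quadratic, used only to exercise sam_step.
    return w[0] ** 2 + 10.0 * w[1] ** 2


def quadratic_grad(w):
    return [2.0 * w[0], 20.0 * w[1]]
```

With `rho = 0` the perturbation vanishes and `sam_step` reduces to plain gradient descent, which makes `--rho` the single knob separating SAM from the base optimizer.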
Results (accuracy in %, IF = imbalance factor):

| Method | CIFAR-10 (IF=50) | CIFAR-10 (IF=100) | CIFAR-100 (IF=50) | CIFAR-100 (IF=100) |
|---|---|---|---|---|
| GLMC | 89.81 | 87.55 | 62.49 | 57.63 |
| GLMC + SAM | 91.56 | 89.18 | 65.28 | 59.01 |
We also release the code to compute the spectral density and analyse the loss landscape of the trained models.
Sample command below:
python hessian_analysis.py --gpu 0 --seed 1 --exp_str sample --resume <checkpoint_path> --dataloader_hess train --log_results
Running this command computes the eigenvalue spectral density of the Hessian of each per-class loss and plots the class-wise spectral densities along with the maximum eigenvalue and the trace of the Hessian.
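The maximum eigenvalue and the trace can both be estimated from Hessian-vector products alone, without forming the Hessian; PyHessian, which this repository refers to, works from such products. The sketch below shows power iteration for the top eigenvalue and Hutchinson's estimator for the trace. Here `hvp` is a hypothetical stand-in for a Hessian-vector product of a per-class loss (in PyTorch it would come from double back-propagation), and the function names are illustrative only.

```python
import math
import random


def top_eigenvalue(hvp, dim, iters=100, seed=0):
    """Estimate the largest-magnitude Hessian eigenvalue by power iteration,
    using only Hessian-vector products `hvp(v)` (hypothetical callable)."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    eig = 0.0
    for _ in range(iters):
        norm = math.sqrt(sum(x * x for x in v))
        v = [x / norm for x in v]
        hv = hvp(v)
        eig = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient v^T H v
        v = hv
    return eig


def hutchinson_trace(hvp, dim, samples=100, seed=0):
    """Estimate tr(H) as the mean of v^T H v over random Rademacher vectors v."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        total += sum(a * b for a, b in zip(v, hvp(v)))
    return total / samples
```

For a per-class analysis, `hvp` would be built from the loss restricted to that class's samples, so each class gets its own eigenvalue and trace estimate. A negative eigenvalue of large magnitude alongside positive ones is the saddle-point signature discussed in the paper.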
Generally, all Python scripts in the project take the following flags:
- `-a`: Architecture of the backbone (resnet32|resnet50).
- `--dataset`: Dataset (cifar10|cifar100).
- `--imb_type`: Imbalance type (exp|step).
- `--imb_factor`: Imbalance factor (ratio of the number of samples in the minority class to that in the majority class). Default: 0.01.
- `--epochs`: Number of epochs to train for. Default: 200.
- `--loss_type`: Loss type (CE|LDAM|VS).
- `--gpu`: GPU id to use.
- `--rho`: $\rho$ value in SAM (applicable to SAM runs).
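To make `--imb_type` and `--imb_factor` concrete, here is a sketch of how per-class training counts are commonly derived for long-tailed CIFAR splits. It is modelled on the LDAM-DRW data code this repository builds on, which is an assumption; the function name is hypothetical and this is not verified against this repository's loaders.

```python
def img_num_per_cls(img_max, cls_num, imb_type="exp", imb_factor=0.01):
    """Number of training images kept per class in a long-tailed split.

    Sketch modelled on the LDAM-DRW data code (assumption).
    exp:  counts decay geometrically from img_max (class 0) down to
          img_max * imb_factor (last class).
    step: the first half of the classes keep img_max images, the second
          half keep img_max * imb_factor.
    """
    if imb_type == "exp":
        return [int(img_max * imb_factor ** (i / (cls_num - 1.0)))
                for i in range(cls_num)]
    if imb_type == "step":
        half = cls_num // 2
        return [img_max] * half + [int(img_max * imb_factor)] * (cls_num - half)
    return [img_max] * cls_num  # no imbalance
```

Under this scheme, CIFAR-10 (5000 training images per class) with `--imb_factor 0.01` keeps 5000 images for the head class and 50 for the tail class, i.e. the IF=100 setting reported in the results table.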
Our implementation is based on the LDAM and VS-Loss codebases. We use the PyTorch implementation of SAM from https://github.com/davda54/sam. We refer to PyHessian for the computation of the eigenvalue spectral density and the loss landscape analysis. We thank the authors for releasing their source code publicly.
The implementation of GLMC+SAM is based on GLMC codebase. We thank the authors for publicly releasing the code.
If you find our paper or codebase useful, please consider citing us as:
@inproceedings{
rangwani2022escaping,
title={Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data},
author={Harsh Rangwani and Sumukh K Aithal and Mayank Mishra and Venkatesh Babu Radhakrishnan},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=9DYKrsFSU2}
}