DPViT

PyTorch implementation of DPViT (accepted at NeurIPS 2023)

paper poster slides

Mitigating the Effect of Incidental Correlations on Part-based Learning
Gaurav Bhatt*, Deepayan Das, Leonid Sigal, Vineeth N Balasubramanian


Outline

- Installation
- Data Preparation
- Training DPViT
- Inference
- Evaluation
- Citation

Installation

The code has been tested with the environment specified in environment.yml. To set it up:

git clone https://github.com/GauravBh1010tt/DPViT.git
cd DPViT
conda env create --name dpvit --file=environment.yml
source activate dpvit
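
After the environment is created, a quick optional check (not part of the repository) that PyTorch and the GPUs are visible:

# Optional check (not part of DPViT): verify PyTorch and CUDA are usable.
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPUs visible:", torch.cuda.device_count())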

Data Preparation

We extract the weakly-supervised masks using the RemBG package. Since RemBG is not optimized for batch inference, we modified its code; the modified version is included in DPViT/data_utils/rembg.

cd data_utils
python rembg/nbg_replace.py --dataset=miniIM --img_dir='path_to_train_folder' --batch_size=20

Each image and its corresponding mask are saved as a single image. This speeds up data loading on Slurm by minimizing the number of file I/O calls.
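
The exact packing format is defined in DPViT/data_utils/rembg; as a rough illustration only (the side-by-side layout and helper names below are assumptions, not the repository's actual format), packing and unpacking an image/mask pair could look like this:

# Illustrative sketch only: assumes the image and mask are packed side by side
# into one file. The real layout is defined by DPViT/data_utils/rembg.
from PIL import Image

def pack_image_and_mask(img_path, mask_path, out_path):
    """Save an RGB image and its mask as one side-by-side image (one file, one I/O call)."""
    img = Image.open(img_path).convert("RGB")
    mask = Image.open(mask_path).convert("RGB").resize(img.size)
    packed = Image.new("RGB", (img.width * 2, img.height))
    packed.paste(img, (0, 0))
    packed.paste(mask, (img.width, 0))
    packed.save(out_path)

def unpack_image_and_mask(packed_path):
    """Split the packed file back into (image, mask) at load time."""
    packed = Image.open(packed_path).convert("RGB")
    w = packed.width // 2
    image = packed.crop((0, 0, w, packed.height))
    mask = packed.crop((w, 0, packed.width, packed.height))
    return image, mask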

miniImageNet

Use mini-imagenet-tools to create the miniImageNet dataset. Please note that all datasets should follow an ImageNet-style folder layout and look like this:

|-- miniimagenet
|   |-- train
|   |   |-- n908761
|   |   |-- n453897
|   |   |-- ...
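
As a quick optional sanity check (not part of the repository; the path below is a placeholder), you can confirm that such a layout is readable with torchvision's ImageFolder before launching training:

# Optional sanity check (not part of DPViT): verify the ImageNet-style layout loads.
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
train_set = datasets.ImageFolder("path/to/miniimagenet/train", transform=transform)  # placeholder path
print(len(train_set.classes), "classes,", len(train_set), "images")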

ImageNet-9

Download the ImageNet-9 dataset from the ImageNet-9 repository. The file structure looks like this:

|-- in9
|   |--train
|   |   |-- 00_dog
|   |   |-- 01_bird
|   |   |   |-- ...
|   |-- bg_challenge
|   |   |-- mixed_same
|   |   |   |-- 00_dog
|   |   |   |-- 01_bird
|   |   |   |-- ...
|   |   |-- mixed_rand
|   |   |-- ...

Training DPViT

On local machine:

bash scripts/run_local.sh

Update the hyper-parameters in the run_local.sh file before launching.

On Slurm:

bash scripts/run_slurm.sh

You need to update the cluster-specific configuration in the run_slurm.sh file.

GPU requirement

We train DPViT on 4 A100 GPUs with 40 GB of VRAM each. Adjust the batch size according to your hardware and memory budget.
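
When scaling down to fewer or smaller GPUs, one common approach is to keep the effective batch size roughly constant by adjusting the per-GPU batch size (and gradient accumulation, if the training script supports it). A back-of-the-envelope sketch; the variable names and values below are illustrative, not actual script flags or repository defaults:

# Rough arithmetic only; actual flag names live in run_local.sh / run_slurm.sh.
num_gpus = 4            # our reported setup: 4 x A100 40 GB
per_gpu_batch = 64      # illustrative value, not the repository default
grad_accum_steps = 1    # increase this if you must shrink per_gpu_batch
effective_batch = num_gpus * per_gpu_batch * grad_accum_steps
print("effective batch size:", effective_batch)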

Inference

By default, visualizations are saved inside the experiment folder: <exp_name>/visualization_epoch<#>. Inference can be run on the images inside the img_viz folder using the following command:

python eval/eval_dpvit.py --ckp_path="path to saved checkpoint" --eval=0 --viz=1 --image_path='img_viz'

Evaluation

Few-shot

python eval/eval_dpvit.py --ckp_path="path to saved checkpoint" --eval=1 --num_shots=5

ImageNet-9

python eval/imagenet_cls.py --pretrained_weights="saved model" --data_path "data/in9" --partition "bg_challenge/mixed_same/val" --num_classes 9

Choose one of the following ImageNet-9 partitions: mixed_same, mixed_rand, ...
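
If you want to evaluate several partitions in one go, a small optional helper (not part of the repository; the checkpoint path is a placeholder) can simply invoke the command above once per partition:

# Optional helper (not part of DPViT): run the documented evaluation command per partition.
import subprocess

for part in ["mixed_same", "mixed_rand"]:        # extend with other ImageNet-9 partitions as needed
    subprocess.run([
        "python", "eval/imagenet_cls.py",
        "--pretrained_weights", "path/to/saved_model",   # placeholder checkpoint path
        "--data_path", "data/in9",
        "--partition", f"bg_challenge/{part}/val",
        "--num_classes", "9",
    ], check=True)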

Citation

If you find this repo useful, please cite:

@inproceedings{bhatt2023mitigating,
  title={Mitigating the Effect of Incidental Correlations on Part-based Learning},
  author={Bhatt, Gaurav and Das, Deepayan and Sigal, Leonid and Balasubramanian, Vineeth N},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023}
}