Codebase for benchmarking robustness of Self-Supervised Learning across diverse downstream tasks

layer6ai-labs/ssl-robustness

SSL-Robustness

This is the codebase for Benchmarking Robust Self-Supervised Learning Across Diverse Downstream Tasks, accepted at ICML 2024 Workshop FM-Wild.

Installation & Usage

Install Python 3.10

conda create -n robust-ssl python=3.10
conda activate robust-ssl

Install PyTorch 1.13.0

conda install pytorch=1.13.0 torchvision pytorch-cuda=[YOUR_CUDA_VERSION_HERE] -c pytorch -c nvidia

Install the remaining requirements

conda install --file requirements.txt
conda install -c conda-forge timm
pip install hydra-core --upgrade
pip install -U albumentations
pip install -U datasets
pip install git+https://github.com/fra31/auto-attack
pip install pytorch-lightning lightning-bolts wandb
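
After installing, a quick check that the packages are importable can save a failed run later. The following is a minimal sketch, not part of the repo; the import names (for example `hydra` for hydra-core, `autoattack` for auto-attack, `pl_bolts` for lightning-bolts) are assumptions about how these packages expose themselves.

```python
from importlib.util import find_spec

# Assumed top-level import names for the packages installed above.
REQUIRED_MODULES = [
    "torch", "torchvision", "timm", "hydra", "albumentations",
    "datasets", "autoattack", "pytorch_lightning", "pl_bolts", "wandb",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level modules from `names` that cannot be located."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("All set." if not missing else f"Missing: {', '.join(missing)}")
```

The check only locates the modules and does not import them, so it will not catch a broken CUDA build.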

We use Hydra to organize the parameters. There are currently three configuration groups: action, task, and ssl_model. You can update the hyperparameters for each action and task in its corresponding YAML file.

To run main.py, specify the action, the task, and the SSL encoder to use:

python main.py +action=$ACTION +task=$TASK +ssl_model=$ENCODER task.dataset=$DATASET

Two actions are available:

downstream_training: trains a downstream model (segmenter, classifier, or depth estimator) on top of the encoder and reports the validation accuracy.

adversarial_attack: attacks both the encoder and the downstream model and reports the adversarial accuracy. For segmentation, only the encoder is attacked for now, and the adversarial mean accuracy and mean IoU are reported.

For task, you can select one of three tasks: classification, segmentation, or depth_estimation. For ssl_model, you can select any of the files inside conf/ssl_model/; just drop the .yaml suffix. Finally, for task.dataset, you can select any dataset listed in the Datasets section.
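
To launch several runs from a script, the command line can be assembled programmatically. This is a minimal sketch under the assumption that every override is a single `key=value` token; `build_command` is a hypothetical helper, not something the repo provides, and it does not check whether a given model, task, and dataset combination is actually supported.

```python
import shlex

# Action and task names as listed in this README.
ACTIONS = ("downstream_training", "adversarial_attack")
TASKS = ("classification", "segmentation", "depth_estimation")


def build_command(action, task, ssl_model, dataset):
    """Assemble the main.py invocation as an argument list (hypothetical helper)."""
    if action not in ACTIONS:
        raise ValueError(f"unknown action: {action!r}")
    if task not in TASKS:
        raise ValueError(f"unknown task: {task!r}")
    return [
        "python", "main.py",
        f"+action={action}",
        f"+task={task}",
        f"+ssl_model={ssl_model}",   # config file name without the .yaml suffix
        f"task.dataset={dataset}",
    ]


cmd = build_command("downstream_training", "classification", "dino_v1_vits16", "cifar100")
print(shlex.join(cmd))
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)` from the repository root.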

Checkpoints

Please download the checkpoints folder from this drive and put it under tmp/ to get started.

Checkpoints/SSL models that we currently support:

| Pretrained | SSL Model | Finetuned? | Backbone | Config | Source |
| --- | --- | --- | --- | --- | --- |
| CIFAR10 | SimCLR | No | ResNet-50 | simclr_cifar10_resnet50 | Pytorch SimCLR |
| CIFAR10 | SimCLR | No | ResNet-18 | simclr_cifar10_resnet18_solo | solo-learn |
| CIFAR10 | SimCLR | DeACL | ResNet-18 | simclr_cifar10_resnet18_DeACL | Finetuned with DeACL from solo-learn checkpoint |
| CIFAR10 | SimSiam | No | ResNet-18 | simsiam_cifar10_resnet18_solo | solo-learn |
| ImageNet | SimCLRv1 | No | ResNet-50 | simclr_v1_imagenet_resnet50 | How Well Do Self-Supervised Models Transfer? repo + SimCLRv1 tf to torch converter |
| ImageNet | SimCLRv2 | No | ResNet-50 | simclr_v2_imagenet_resnet50 | How Well Do Self-Supervised Models Transfer? repo + SimCLRv2 tf to torch converter |
| ImageNet | DINOv1 | No | ViT-small (patch size 8) | dino_v1_vits8 | dino |
| ImageNet | DINOv1 | No | ViT-small (patch size 16) | dino_v1_vits16 | dino |
| ImageNet | DINOv1 | No | ViT-tiny (patch size 16) | dino_v1_vitt16_ours | Trained by SprintML |
| ImageNet | DINOv1 | Adversarial training | ViT-small (patch size 16) | dino_v1_vits16_adv_ours | Trained by SprintML |
| DINO | DINOv2 | No | ViT-small (patch size 14) | dino_v2_vits14_distilled | dinov2 |

Adding an SSL model

To add an SSL model, create a new config for it under conf/ssl_model/ with the required parameters. For example, to add dino_vitb16 from the dino repo, create a file named dino_v1_vitb16.yaml under conf/ssl_model/ as follows:

name: "dino_v1_vitb16" # The name of the Pretrained SSL model
ssl_model: "dinov1" # Type of SSL model, see models/ssl_models/utils.py for a full list
encoder_arch: "vitb" 
proj_output_dim: 65536
use_bn_in_head: False
feature_dim: 768 
num_heads: 12 
patch_size: 16 
image_size: 224
pretrained_url:
  repo: "facebookresearch/dino:main"
  name: "dino_vitb16"
ckpt:
  ckpts_folder: null 
  final_ckpt: "tmp/checkpoints/imagenet/dino/official/dino_vitb16.ckpt" 
  encoder_name: "backbone"
  projector_name: "head"
  non_standard_naming: False
  state_dict_key: "teacher"
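
One plausible reading of the ckpt fields, which this README does not spell out, is that state_dict_key selects a sub-dictionary of the checkpoint and that encoder_name and projector_name are key prefixes inside it. The sketch below illustrates that reading on a toy dictionary; `split_state_dict` is a hypothetical helper, and the repo's actual loading code may differ.

```python
def split_state_dict(checkpoint, state_dict_key="teacher",
                     encoder_name="backbone", projector_name="head"):
    """Split a checkpoint into encoder and projector weights by key prefix.

    Assumption: weights live under checkpoint[state_dict_key] with keys such
    as "backbone.<param>" and "head.<param>".
    """
    state = checkpoint[state_dict_key] if state_dict_key else checkpoint
    encoder, projector = {}, {}
    for key, value in state.items():
        if key.startswith(encoder_name + "."):
            encoder[key[len(encoder_name) + 1:]] = value
        elif key.startswith(projector_name + "."):
            projector[key[len(projector_name) + 1:]] = value
    return encoder, projector


# Toy checkpoint standing in for a real file loaded with torch.load.
toy = {"teacher": {"backbone.cls_token": 0, "head.last_layer.weight": 1}}
encoder, projector = split_state_dict(toy)
print(sorted(encoder), sorted(projector))
```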

There are two ways to specify the pretrained checkpoint paths:

  • To examine a single checkpoint (usually the final one) of the trained SSL model, pass its full file path as final_ckpt under ckpt.
  • To examine several checkpoints of an encoder (such as a training history), pass the folder's name as ckpts_folder under ckpt. The code loads every file in that folder, so make sure it contains nothing but the checkpoints to be tested.
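
The two options can be pictured as a small resolution step that turns the ckpt section into a list of files. This is an illustrative sketch only: `resolve_checkpoints` is a hypothetical function, and giving ckpts_folder precedence when both fields are set is an assumption, not documented behavior.

```python
from pathlib import Path


def resolve_checkpoints(final_ckpt=None, ckpts_folder=None):
    """Return the checkpoint files to evaluate (hypothetical helper)."""
    if ckpts_folder is not None:
        # Every file in the folder is treated as a checkpoint, which is why
        # the folder should contain nothing else.
        return sorted(p for p in Path(ckpts_folder).iterdir() if p.is_file())
    if final_ckpt is not None:
        return [Path(final_ckpt)]
    raise ValueError("Set either final_ckpt or ckpts_folder.")


print(resolve_checkpoints(final_ckpt="tmp/checkpoints/imagenet/dino/official/dino_vitb16.ckpt"))
```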

Datasets

The code supports the following datasets: cifar100, imagenet, foodseg103, mnist, fashion-mnist, flowers102, food101, stl10, ADE20K, CityScapes, Pascal VOC 2012, and NYU Depth v2. The last four need extra steps to download or parse; see the notes below.

ADE20k

We use ADE20K from the MIT Scene Parsing Benchmark.

To load ADE20K:

  1. Download the data from here (link→downloads→Scene Parsing→Data:[train/val(922MB)]).
  2. Place the zip file in tmp/data/ade20k/.
  3. Unzip it.
  4. Run python3 data/ade20k_to_hgdataset.py --path_to_ade20k tmp/data/ade20k/ADEChallengeData2016 --path_to_hgdataset tmp/data/ade20k_hg

CityScapes

We use CityScapes from the Semantic Understanding of Urban Street Scenes.

To load CityScapes:

  1. Download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip from here.
  2. Unzip both into the same directory, tmp/data/cityscapes, so that it contains two folders: gtFine and leftImg8bit.
  3. Run python3 data/cityscapes_to_hg_dataset.py
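
Before running the conversion script, it can help to confirm that the unzipped data has the layout described in the steps above. The check below is a small sketch, not part of the repo; `missing_cityscapes_dirs` is a hypothetical helper and only looks for the two top-level folders.

```python
from pathlib import Path

# Folder names taken from the steps above.
EXPECTED_DIRS = ("gtFine", "leftImg8bit")


def missing_cityscapes_dirs(root="tmp/data/cityscapes"):
    """Return the expected top-level folders that are absent under `root`."""
    root = Path(root)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_cityscapes_dirs()
    print("Layout looks fine." if not missing else f"Missing folders: {missing}")
```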

Pascal VOC 2012

We use Pascal VOC 2012 from the Visual Object Classes Challenge 2012 (VOC2012).

To load Pascal VOC 2012:

  1. Download the data: wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
  2. Extract it: mkdir -p tmp/data/voc && tar -xvf VOCtrainval_11-May-2012.tar -C tmp/data/voc
  3. Run python3 data/pascal_voc_to_hg_dataset.py
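
On machines without wget or tar, the download and extraction steps can be reproduced with the Python standard library. This is a rough equivalent under the same paths as above; `extract_tar` and `download_and_extract` are hypothetical helpers, and the conversion script still has to be run afterwards.

```python
import tarfile
import urllib.request
from pathlib import Path

VOC_URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"


def extract_tar(archive, dest="tmp/data/voc"):
    """Extract `archive` into `dest`, creating the directory if needed."""
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    # Only extract archives from a source you trust.
    with tarfile.open(archive) as tar:
        tar.extractall(dest)
    return dest


def download_and_extract(url=VOC_URL, dest="tmp/data/voc"):
    """Download the archive into the current directory (if absent) and extract it."""
    archive = Path(url.rsplit("/", 1)[-1])
    if not archive.exists():
        urllib.request.urlretrieve(url, archive)
    return extract_tar(archive, dest)
```

Calling `download_and_extract()` from the repository root mirrors steps 1 and 2.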

NYU Depth V2

We use the HuggingFace NYU Depth V2 dataset. It is downloaded automatically, but loading and processing the images takes several hours.

Code formatting

This repo uses black for Python formatting.
