SimPro

This repository contains the PyTorch implementation of the ICML 2024 paper SimPro: A Simple Probabilistic Framework Towards Realistic Long-Tailed Semi-Supervised Learning.

Authors: Chaoqun Du, Yizeng Han, Gao Huang.


Introduction


The general idea of SimPro addressing the ReaLTSSL problem. (a) Current methods typically rely on predefined or assumed class distribution patterns for unlabeled data, limiting their applicability. (b) In contrast, our SimPro embraces a more realistic scenario by introducing a simple and elegant framework that operates effectively without making any assumptions about the distribution of unlabeled data. This paradigm shift allows for greater flexibility and applicability in diverse ReaLTSSL scenarios.


The SimPro Framework Overview. This framework distinctively separates the conditional and marginal (class) distributions. In the E-step (top), pseudo-labels are generated using the current parameters $\theta$ and $\pi$. In the subsequent M-step (bottom), these pseudo-labels, along with the ground-truth labels, are utilized to compute the Cross-Entropy loss, facilitating the optimization of network parameters $\theta$ via gradient descent. Concurrently, the marginal distribution parameter $\pi$ is recalculated using a closed-form solution based on the generated pseudo-labels.
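To make the E-step/M-step alternation concrete, below is a minimal, self-contained sketch of one iteration. It is not the repository's actual training loop: the function name em_step, the tensor names (x_l, y_l, x_u), and the exact forms of the pseudo-label adjustment and of the closed-form $\pi$ update are illustrative assumptions.

import torch
import torch.nn.functional as F

def em_step(model, optimizer, x_l, y_l, x_u, pi):
    # Illustrative sketch of one simplified SimPro-style EM iteration.
    # E-step: generate pseudo-labels for unlabeled data with the current theta and pi.
    with torch.no_grad():
        logits_u = model(x_u)
        # Combine the conditional prediction with the current marginal pi (log-prior adjustment).
        probs_u = F.softmax(logits_u + torch.log(pi + 1e-12), dim=1)
        pseudo_labels = probs_u.argmax(dim=1)

    # M-step (theta): cross-entropy on ground-truth and pseudo-labels, optimized by gradient descent.
    logits = model(torch.cat([x_l, x_u]))
    targets = torch.cat([y_l, pseudo_labels])
    loss = F.cross_entropy(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # M-step (pi): closed-form update of the marginal as the mean pseudo-label posterior.
    new_pi = probs_u.mean(dim=0)
    return loss.item(), new_pi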

Performance

Performance comparison figures; see the paper for the full results.

Get Started

Requirements

  • python 3.9
  • numpy 1.25.2
  • Pillow 9.2.0
  • Requests 2.28.1
  • scipy 1.9.3
  • torch 1.12.1
  • torchvision 0.13.1

The environment above is recommended but not strictly required; other package versions may also work.
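For reference, the pinned versions above can be installed with pip in a Python 3.9 environment (assuming you choose a torch/torchvision build matching your CUDA setup):

pip install numpy==1.25.2 Pillow==9.2.0 Requests==2.28.1 scipy==1.9.3 torch==1.12.1 torchvision==0.13.1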

Datasets

Follow the instructions from here to download small_imagenet_1k_32 & small_imagenet_1k_64 datasets and prepare small_imagenet_127_32 & small_imagenet_127_64 datasets.

The default data path is set to /home/data/ and the data structure should be as follows:

/home/data/
├── cifar-10
├── cifar-100
├── stl-10
├── small_imagenet_127_32
├── small_imagenet_127_64
├── small_imagenet_1k_32
└── small_imagenet_1k_64
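As a quick sanity check, a short snippet like the following (illustrative only, assuming the default /home/data/ root) can confirm that the expected dataset folders are in place before training:

import os

DATA_ROOT = "/home/data"  # default data path; adjust if your data lives elsewhere
EXPECTED = [
    "cifar-10", "cifar-100", "stl-10",
    "small_imagenet_127_32", "small_imagenet_127_64",
    "small_imagenet_1k_32", "small_imagenet_1k_64",
]
missing = [name for name in EXPECTED if not os.path.isdir(os.path.join(DATA_ROOT, name))]
print("Missing dataset folders:", missing if missing else "none")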

Training

By default, we use one RTX 3090 GPU for the CIFAR/STL-10/ImageNet-127/1k ($32\times32$) datasets and one A100 GPU (40 GB) for the ImageNet-127/1k ($64\times64$) datasets.

bash sh/${method}.sh ${dataset} ${exp_index} #comment for the experiment setting

bash sh/SimPro.sh cifar10 0          #gamma_l=150 consistent
bash sh/SimPro.sh cifar10 1          #gamma_l=150 uniform
bash sh/SimPro.sh cifar10 2          #gamma_l=150 reversed
bash sh/SimPro.sh cifar10 3          #gamma_l=150 middle
bash sh/SimPro.sh cifar10 4          #gamma_l=150 head-tail
bash sh/SimPro.sh cifar10 5          #gamma_l=100 consistent
bash sh/SimPro.sh cifar10 6          #gamma_l=100 uniform
bash sh/SimPro.sh cifar10 7          #gamma_l=100 reversed
bash sh/SimPro.sh cifar10 8          #gamma_l=100 middle
bash sh/SimPro.sh cifar10 9          #gamma_l=100 head-tail

bash sh/SimPro.sh cifar100 0         #gamma_l=20 consistent
bash sh/SimPro.sh cifar100 1         #gamma_l=20 uniform
bash sh/SimPro.sh cifar100 2         #gamma_l=20 reversed
bash sh/SimPro.sh cifar100 3         #gamma_l=20 middle
bash sh/SimPro.sh cifar100 4         #gamma_l=20 head-tail

bash sh/SimPro.sh stl10 0            #gamma_l=10
bash sh/SimPro.sh stl10 1            #gamma_l=20

bash sh/SimPro.sh smallimagenet    0 #gamma_t=1 ImageNet-127 32*32
bash sh/SimPro.sh smallimagenet    1 #gamma_t=1 ImageNet-127 64*64
bash sh/SimPro.sh smallimagenet_1k 0 #gamma_t=1 ImageNet-1k  32*32
bash sh/SimPro.sh smallimagenet_1k 1 #gamma_t=1 ImageNet-1k  64*64

bash sh/ACR_d.sh ...

ACR_d is a variant of the ACR baseline without pre-defined anchor distributions, corresponding to ACR$^\dagger$ in the paper. We modified the computation of adjustment_l1 by replacing the dynamic tau (estimated from the anchor distributions) with a fixed tau1. Specifically, we only changed:

args.adjustment_l1 = compute_adjustment_by_py(args.py_con, tau, args)

to:

args.adjustment_l1 = compute_adjustment_by_py(args.py_con, args.tau1, args)
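For context, adjustments of this form are typically a temperature-scaled log class prior (the standard logit-adjustment recipe). The sketch below is a hypothetical approximation of what compute_adjustment_by_py might compute, not the repository's actual implementation:

import torch

def compute_adjustment_by_py(py, tau, args=None):
    # Hypothetical sketch: scale the log of the (estimated) class prior `py` by the temperature `tau`.
    # The real function in the codebase may differ in details (e.g., normalization, device handling).
    return tau * torch.log(py + 1e-12)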

Citation

If you find this code useful, please consider citing our paper:

@inproceedings{du2024simpro,
    title={SimPro: A Simple Probabilistic Framework Towards Realistic Long-Tailed Semi-Supervised Learning},
    author={Chaoqun Du and Yizeng Han and Gao Huang},
    booktitle={Forty-first International Conference on Machine Learning},
    year={2024},
}

Contact

If you have any questions, please feel free to contact the authors: Chaoqun Du (dcq20@mails.tsinghua.edu.cn).

Acknowledgement

Our code is based on ACR (Towards Realistic Long-Tailed Semi-Supervised Learning: Consistency Is All You Need).
