Skip to content

hizieun/MetaPseudoLabels

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

16 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Hits

Meta Pseudo Labels

This is an unofficial PyTorch implementation of Meta Pseudo Labels. The official Tensorflow implementation is here.

Results

CIFAR-10-4K SVHN-1K ImageNet-10%
Paper (w/ finetune) 96.11 ± 0.07 98.01 ± 0.07 73.89
This code (w/o finetune) 94.46 - -
This code (w/ finetune) WIP - -
Acc. curve link - -
  • I have experienced some difficulties while reproducing paper's result.
  • Please let me know where to modify my code! (issue)

Usage

Train the model by 4000 labeled data of CIFAR-10 dataset:

teacher, student 모델로 MPL 알고리즘으로 학습

python main.py --seed 5 --name cifar10-4K.5 --expand-labels --dataset cifar10 --num-classes 10 --num-labeled 4000 --total-steps 300000 --eval-step 1000 --randaug 2 16 --batch-size 128 --teacher_lr 0.05 --student_lr 0.05 --weight-decay 5e-4 --ema 0.995 --nesterov --mu 7 --label-smoothing 0.15 --temperature 0.7 --threshold 0.6 --lambda-u 8 --warmup-steps 5000 --uda-steps 5000 --student-wait-steps 3000 --teacher-dropout 0.2 --student-dropout 0.2 --amp

Train the model by 10000 labeled data of CIFAR-100 dataset by using DistributedDataParallel:

python -m torch.distributed.launch --nproc_per_node 4 main.py --seed 5 --name cifar100-10K.5 --dataset cifar100 --num-classes 100 --num-labeled 10000 --expand-labels --total-steps 300000 --eval-step 1000 --randaug 2 16 --batch-size 128 --teacher_lr 0.05 --student_lr 0.05 --weight-decay 5e-4 --ema 0.995 --nesterov --mu 7 --label-smoothing 0.15 --temperature 0.7 --threshold 0.6 --lambda-u 8 --warmup-steps 5000 --uda-steps 5000 --student-wait-steps 3000 --teacher-dropout 0.2 --student-dropout 0.2 --amp

Monitoring training progress

tensorboard --logdir results

FineTune

Student 모델을 train_loader(labeled dataset)으로 학습

python main.py --finetune  --data-path ../../../data/dogs-vs-cats --seed 5 --name dogs-vs-cats --dataset custom --num-classes 2 --finetune-epochs 125  --finetune-batch-size 64 --finetune-lr 1e-5  --finetune-weight-decay 0 --finetune-momentum 0 --amp

Evaluate

student 모델로 test_loader에 대해서 테스트

python main.py --data-path ../../../data/dogs-vs-cats --seed 5 --name dogs-vs-cats  --dataset custom --num-classes 2  --randaug 2 16 --batch-size 8  --amp --evaluate

Requirements

  • python 3.6+
  • torch 1.7+
  • torchvision 0.8+
  • tensorboard
  • wandb
  • numpy
  • tqdm
  • pandas