Masked Distillation Pretraining

PyTorch reimplementation of MaskDistill from "A Unified View of Masked Image Modeling". This repo provides many different masked distillation pretraining configurations explored in the original work.

Requirements

  • Python 3.8+
  • pip install -r requirements.txt

Usage

To pretrain a ViT-B/16 network following MaskDistill's default configuration, run:

python train.py --accelerator gpu --devices 1 --precision 16 --data.root path/to/data \
    --max_epochs 300 --data.batch_size 256 --model.student_model vit_base_patch16 \
    --model.teacher_model clip_vit_large_patch14 --model.head_model fc --model.lr 2.5e-4 \
    --data.mask_ratio 0.4
  • Run python train.py --help for descriptions of all options.

  • --model.student_model can be one of:

    • vit_tiny_patch16, vit_small_patch16, vit_base_patch16, vit_large_patch16, vit_huge_patch14
    • swin_tiny_patch4, swin_small_patch4, swin_base_patch4, swin_large_patch4
  • --model.teacher_model can be one of:

    • clip_vit_base_patch32, clip_vit_base_patch16, clip_vit_large_patch14
    • openclip_vit_base_patch32, openclip_vit_large_patch14, openclip_vit_giant_patch14, openclip_vit_huge_patch14
    • pixel (predict masked patches' pixel values)
  • --model.head_model can be one of: fc, vit (see the example configuration below).
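
For example (an illustrative combination of the options listed above, not one of the paper's reported settings), a smaller-scale run could pair a ViT-S/16 student with a CLIP ViT-B/16 teacher:

python train.py --accelerator gpu --devices 1 --precision 16 --data.root path/to/data \
    --max_epochs 300 --data.batch_size 256 --model.student_model vit_small_patch16 \
    --model.teacher_model clip_vit_base_patch16 --model.head_model fc --model.lr 2.5e-4 \
    --data.mask_ratio 0.4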

Using a Pretrained Model

Student weights can be extracted from a pretraining checkpoint file by running:

python scripts/extract_student_weights.py -c path/to/checkpoint/file

You can then initialize a ViT model with these weights as follows:

import torch
from timm.models.vision_transformer import VisionTransformer

weights = torch.load("path/to/weights/file")

# Assuming weights are for a ViT-b/16 model
model = VisionTransformer(
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
)
model.load_state_dict(weights)
  • Note: VisionTransformer arguments should match those used during pretraining (e.g. ViT-B/16, ViT-L/16, etc.).
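
For fine-tuning on a downstream task, a minimal sketch might attach a fresh classification head and load the backbone weights non-strictly (this assumes the extracted checkpoint contains only backbone parameters; the 10-class head is a hypothetical example):

import torch
from timm.models.vision_transformer import VisionTransformer

weights = torch.load("path/to/weights/file")

# Hypothetical fine-tuning setup: a ViT-B/16 backbone with a new 10-class head.
# strict=False skips the randomly initialized head, which is not in the checkpoint.
model = VisionTransformer(
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    num_classes=10,
)
missing, unexpected = model.load_state_dict(weights, strict=False)
print("Missing keys:", missing)        # expected: the new head's parameters
print("Unexpected keys:", unexpected)  # expected: empty for a matching backbone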

If you want to keep the pretrained relative position embeddings and/or the Layer Scale gamma parameters, you can run:

python scripts/extract_student_weights.py -c path/to/checkpoint/file --keep_rel_pos_bias \
    --keep_layer_scale_gamma

To load the relative position embeddings, you will need to initialize a BEiT model as follows:

import torch
from timm.models.beit import Beit

weights = torch.load("path/to/weights/file")

model = Beit(
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    use_rel_pos_bias=True,  # load the pretrained relative position biases
    init_values=0.1,        # enable Layer Scale so the gamma parameters can be loaded
)
model.load_state_dict(weights)
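
As a quick sanity check (illustrative only; timm's Beit defaults to 224x224 inputs and a 1000-class head), you can push a dummy batch through the loaded model:

import torch

x = torch.randn(1, 3, 224, 224)  # dummy ImageNet-sized input
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000]) with the default classification head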

Citation

@article{peng2022unified,
  title={A unified view of masked image modeling},
  author={Peng, Zhiliang and Dong, Li and Bao, Hangbo and Ye, Qixiang and Wei, Furu},
  journal={arXiv preprint arXiv:2210.10615},
  year={2022}
}