📎 [Arxiv], 🚀 [NVlabs/MaskLLM (Official)]
This repo contains a minimal re-implementation of the paper "MaskLLM: Learnable Semi-structured Sparsity for Large Language Models" for vision tasks.
- ViTs on ImageNet-1k (Classification)
- DiTs on ImageNet-1k (Generation)
- Multi-modal LLMs
- Other vision tasks (e.g., object detection, segmentation)
- TensorRT examples
To enable MaskLLM on your model, replace nn.Linear with sparsity.maskllm.MaskedLinear. This can be done with the following code snippet:
import sparsity
model = sparsity.utils.replace_linear_with_(model, sparsity.maskllm.MaskedLinear, exclude=[model.head], N=2, M=4, hard=False)
print(model)
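To make the N=2, M=4 arguments concrete, here is a framework-free sketch of the idea described in the MaskLLM paper: each group of 4 consecutive weights has 6 candidate masks with exactly 2 ones, and a learnable logit vector selects among them through Gumbel-softmax sampling. This is not the code of sparsity.maskllm.MaskedLinear. The function names are hypothetical, and reading hard=False as "use the probability-weighted (soft) mask" is an assumption.

```python
import itertools
import math
import random


def candidate_masks(n=2, m=4):
    """Enumerate every binary mask of length m with exactly n ones."""
    masks = []
    for keep in itertools.combinations(range(m), n):
        masks.append([1.0 if i in keep else 0.0 for i in range(m)])
    return masks


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Sample relaxed one-hot probabilities over the candidates."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(u)) for u in (0, 1).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / tau)
    peak = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def sample_mask(logits, n=2, m=4, tau=1.0, hard=False, rng=random):
    """Return a soft (hard=False) or binary (hard=True) mask for one group."""
    cands = candidate_masks(n, m)
    probs = gumbel_softmax(logits, tau, rng)
    if hard:
        # Pick the single most probable candidate.
        return cands[probs.index(max(probs))]
    # Soft mask: probability-weighted sum of all candidates.
    return [sum(p * c[i] for p, c in zip(probs, cands)) for i in range(m)]
```

Calling sample_mask([0.0] * 6, hard=True) returns one of the six binary 2:4 masks, while hard=False returns fractional values that still sum to 2 across the group. Multiplying a group of weights elementwise by the mask gives the sparsified group.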
Method | Sparsity Pattern | Weight Update | Mask Prior | Top-1 Acc. (%) |
---|---|---|---|---|
ViT-B/16 (in1k) | Dense | - | - | 79.15 |
Magnitude | 2:4 | - | - | 65.92 |
Wanda | 2:4 | - | - | 63.28 |
SparseGPT | 2:4 | ✔️ | - | 71.52 |
SparseGPT w/o Update | 2:4 | - | - | 59.72 |
MaskLLM-4V (1 Epoch) | 2:4 | - | SparseGPT | 76.23 |
MaskLLM-4V (1 Epoch) | 2:4 | - | Magnitude | 76.18 |
MaskLLM-4V (20 Epochs) | 2:4 | - | SparseGPT | 79.46 |
MaskLLM-4V (20 Epochs) | 2:4 | - | Magnitude | 79.28 |
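Note: MaskLLM learns a separate mask for sparsification while keeping the network parameters frozen. For ViT-B/16, end-to-end training finds a lossless mask: after 20 epochs the 2:4 model reaches 79.46% (SparseGPT prior) or 79.28% (Magnitude prior) top-1 accuracy, compared with 79.15% for the dense model.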
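For readers new to the 2:4 pattern in the table above, the Magnitude baseline can be illustrated as follows: within each group of 4 consecutive weights, keep the 2 with the largest absolute value and zero the rest. This is a minimal sketch on a plain Python list, not the repository's pruner; the function name is hypothetical and the tie-breaking rule (lower index wins) is an assumption.

```python
def magnitude_nm_mask(row, n=2, m=4):
    """Binary mask keeping the n largest-|w| entries in each group of m."""
    if len(row) % m != 0:
        raise ValueError("row length must be a multiple of m")
    mask = [0] * len(row)
    for start in range(0, len(row), m):
        group = range(start, start + m)
        # Indices of the n largest magnitudes; ties go to the lower index.
        keep = sorted(group, key=lambda i: (-abs(row[i]), i))[:n]
        for i in keep:
            mask[i] = 1
    return mask


row = [0.9, -0.1, 0.05, -1.3, 0.2, 0.2, -0.7, 0.0]
mask = magnitude_nm_mask(row)
pruned = [w * k for w, k in zip(row, mask)]
print(mask)  # [1, 0, 0, 1, 1, 0, 1, 0]
```

Unlike this fixed rule, MaskLLM treats the choice of which 2 weights to keep as a learnable decision, which is why it can recover accuracy without updating the weights.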
Please prepare the ImageNet-1k dataset under the ./data/imagenet directory. The directory structure should be as follows:
data
├── imagenet
│ ├── train
│ │ ├── n01440764
│ │ ├── n01443537
│ │ ├── n01484850
│ │ ├── n01491361
│ └── val
│ │ ├── n01440764
│ │ ├── n01443537
│ │ ├── n01484850
│ │ ├── n01491361
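Before launching a long run, it can help to confirm that the layout above is in place. The following is a small sketch using only the standard library; the helper name check_imagenet_layout is hypothetical and not part of this repository.

```python
from pathlib import Path


def check_imagenet_layout(root="./data/imagenet"):
    """Return {split: number of class folders}; raise if a split is missing."""
    root = Path(root)
    counts = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing directory: {split_dir}")
        # Each class is a sub-directory named by its WordNet ID, e.g. n01440764.
        counts[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
    return counts
```

For a complete ImageNet-1k copy, both splits should report 1000 class folders.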
1.1. MaskLLM for Vision Transformers
We trained MaskLLM on ViT-B/16 using 4 GPUs with 24 GB of memory each. Training requires 17 GB of memory per GPU with a batch size of 128.
We first generate prior masks using one-shot pruning. A prior mask greatly accelerates the convergence of MaskLLM. Set the --pruner argument to magnitude, wanda, or sparsegpt to generate different prior masks.
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner sparsegpt --save-model output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24.pt
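One way a prior mask can speed up convergence is by biasing the initial candidate logits toward masks that agree with the prior. The sketch below illustrates that idea for a single group of 4 weights, in the spirit of the mask-prior initialization described in the MaskLLM paper. It is not the repository's code: the function name and the prior_strength value are hypothetical.

```python
import itertools


def init_logits_from_prior(prior_mask, n=2, prior_strength=3.0):
    """Bias candidate logits toward masks that overlap with the prior."""
    m = len(prior_mask)
    cands = [
        [1 if i in keep else 0 for i in range(m)]
        for keep in itertools.combinations(range(m), n)
    ]
    # Overlap = number of kept positions shared with the prior mask.
    overlap = [sum(c * p for c, p in zip(cand, prior_mask)) for cand in cands]
    mean = sum(overlap) / len(overlap)
    # Centered overlap: candidates closer to the prior get higher logits.
    logits = [prior_strength * (o - mean) for o in overlap]
    return cands, logits


cands, logits = init_logits_from_prior([1, 0, 0, 1])
best = cands[logits.index(max(logits))]
print(best)  # [1, 0, 0, 1]
```

With this initialization, sampling starts concentrated on the one-shot mask but every other candidate keeps a non-zero probability, so training can still move away from the prior.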
We took the training hyperparameters from this timm issue. By default, we train the model with EMA for 20 epochs. For one-epoch training, please disable EMA, as in this script.
bash scripts/maskllm_vit_base_patch16_224.augreg_in1k.sparsegpt24.sh
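For reference, EMA (exponential moving average) keeps a smoothed copy of the trainable parameters that is updated after every optimizer step and used for evaluation. The sketch below shows the generic update rule on a dictionary of floats. The function name and the decay value are hypothetical; the EMA implementation and decay used by the training script are not shown here.

```python
def ema_update(ema_params, params, decay=0.999):
    """One EMA step, in place: ema <- decay * ema + (1 - decay) * param."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params
```

Because the average needs many steps to move away from its starting point, it offers little benefit over a very short schedule, which is consistent with disabling EMA for one-epoch training.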
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/maskllm_vit_base_patch16_224.augreg_in1k.sparsegpt24/MaskLLM-4V/model_best.pth.tar --sparsity-mode maskllm
{
"model": "vit_base_patch16_224",
"top1": 79.456,
"top1_err": 20.544,
"top5": 94.548,
"top5_err": 5.452,
"param_count": 213.97,
"img_size": 224,
"crop_pct": 0.9,
"interpolation": "bicubic"
}
To apply MaskLLM to other models or prior types, change the --model and --checkpoint arguments.
Detailed Instructions
# Eval
python timm_validate.py --model vit_base_patch16_224.augreg_in1k --pretrained
{
"model": "vit_base_patch16_224.augreg_in1k",
"top1": 79.158,
"top1_err": 20.842,
"top5": 94.088,
"top5_err": 5.912,
"param_count": 86.57,
"img_size": 224,
"crop_pct": 0.9,
"interpolation": "bicubic"
}
Detailed Instructions
# Magnitude pruning
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner magnitude --save-model output/pruned/vit_base_patch16_224.augreg_in1k.magnitude24.pt
# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/pruned/vit_base_patch16_224.augreg_in1k.magnitude24.pt --sparsity-mode sparse
{
"model": "vit_base_patch16_224",
"top1": 65.92,
"top1_err": 34.08,
"top5": 86.058,
"top5_err": 13.942,
"param_count": 86.57,
"img_size": 224,
"crop_pct": 0.9,
"interpolation": "bicubic"
}
Detailed Instructions
# Wanda pruning
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner wanda --save-model output/pruned/vit_base_patch16_224.augreg_in1k.wanda24.pt
# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/pruned/vit_base_patch16_224.augreg_in1k.wanda24.pt --sparsity-mode sparse
{
"model": "vit_base_patch16_224",
"top1": 63.282,
"top1_err": 36.718,
"top5": 85.574,
"top5_err": 14.426,
"param_count": 86.57,
"img_size": 224,
"crop_pct": 0.9,
"interpolation": "bicubic"
}
Detailed Instructions
# SparseGPT pruning
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner sparsegpt --save-model output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24.pt
# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24.pt --sparsity-mode sparse
{
"model": "vit_base_patch16_224",
"top1": 59.728,
"top1_err": 40.272,
"top5": 82.326,
"top5_err": 17.674,
"param_count": 86.57,
"img_size": 224,
"crop_pct": 0.9,
"interpolation": "bicubic"
}
# SparseGPT pruning with weight update
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner sparsegpt --save-model output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24_updated.pt --enable-update
# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24_updated.pt --sparsity-mode sparse
{
"model": "vit_base_patch16_224",
"top1": 71.52,
"top1_err": 28.48,
"top5": 90.246,
"top5_err": 9.754,
"param_count": 86.57,
"img_size": 224,
"crop_pct": 0.9,
"interpolation": "bicubic"
}
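After one-shot pruning, it can be useful to confirm that the saved weights really follow the 2:4 pattern. The following sketch checks a single weight row given as a plain Python list. The function name is hypothetical, and it assumes groups are formed from consecutive values along each row (the input dimension); for a real checkpoint one would apply it to every row of every pruned linear layer.

```python
def satisfies_nm(row, n=2, m=4):
    """True if every group of m consecutive values has at most n non-zeros."""
    if len(row) % m != 0:
        return False
    return all(
        sum(1 for w in row[s:s + m] if w != 0) <= n
        for s in range(0, len(row), m)
    )


print(satisfies_nm([0.9, 0.0, 0.0, -1.3, 0.2, 0.0, -0.7, 0.0]))  # True
print(satisfies_nm([1.0, 1.0, 1.0, 0.0]))                        # False
```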
This part is still in progress. Please stay tuned.
Please prepare the ImageNet-1k dataset under the ./data/imagenet directory. The directory structure should be as follows:
data
├── imagenet
│ ├── train
│ │ ├── n01440764
│ │ ├── n01443537
│ │ ├── n01484850
│ │ ├── n01491361
│ └── val
│ │ ├── n01440764
│ │ ├── n01443537
│ │ ├── n01484850
│ │ ├── n01491361
2.1 MaskLLM for Diffusion Transformers
TODO
python sample.py --model DiT-XL/2
python oneshot_pruning_dit.py --model DiT-XL/2 --pruner magnitude
python oneshot_pruning_dit.py --model DiT-XL/2 --pruner wanda
python oneshot_pruning_dit.py --model DiT-XL/2 --pruner sparsegpt
python oneshot_pruning_dit.py --model DiT-XL/2 --pruner sparsegpt --enable-update
This project is based on the following repositories:
If you find this repository helpful, please consider citing the following paper.
@article{fang2024maskllm,
title={Maskllm: Learnable semi-structured sparsity for large language models},
author={Fang, Gongfan and Yin, Hongxu and Muralidharan, Saurav and Heinrich, Greg and Pool, Jeff and Kautz, Jan and Molchanov, Pavlo and Wang, Xinchao},
journal={arXiv preprint arXiv:2409.17481},
year={2024}
}