
VQ-Prompt

Official PyTorch code for "Vector Quantization Prompting for Continual Learning (NeurIPS2024)".

Abstract

Continual learning requires overcoming catastrophic forgetting when training a single model on a sequence of tasks. Recent top-performing approaches are prompt-based methods that utilize a set of learnable parameters (i.e., prompts) to encode task knowledge, from which appropriate ones are selected to guide the fixed pre-trained model in generating features tailored to a certain task. However, existing methods rely on predicting prompt identities for prompt selection, where the identity prediction process cannot be optimized with task loss. This limitation leads to sub-optimal prompt selection and a failure to adapt pre-trained features for the specific task. Previous efforts have tried to address this by directly generating prompts from input queries instead of selecting from a set of candidates, which, however, results in continuous prompts that lack sufficient abstraction for effective task knowledge representation. To address these challenges, we propose VQ-Prompt, a prompt-based continual learning method that incorporates Vector Quantization (VQ) into end-to-end training of a set of discrete prompts. Without the need for storing past data, VQ-Prompt outperforms state-of-the-art continual learning methods across a variety of benchmarks under the challenging class-incremental setting.
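The core idea from the abstract can be sketched in a few lines: quantize a continuous query into the nearest entry of a learnable prompt pool, and pass gradients through the discrete selection with a straight-through estimator so the whole pipeline trains end-to-end with the task loss. This is a minimal, hypothetical illustration only; the class name, `prompt_pool` attribute, and dimensions are assumptions, not the repository's actual API.

```python
import torch
import torch.nn as nn

class VQPromptSketch(nn.Module):
    """Minimal sketch (assumed names/shapes): pick the nearest discrete
    prompt for a query; use the straight-through estimator so the
    selection remains trainable with the task loss."""

    def __init__(self, num_prompts=10, dim=768):
        super().__init__()
        # Learnable pool of discrete prompt "codewords".
        self.prompt_pool = nn.Parameter(torch.randn(num_prompts, dim))

    def forward(self, query):
        # query: (batch, dim) feature from the frozen pre-trained model.
        # Nearest-neighbor lookup in the pool = vector quantization.
        dists = torch.cdist(query, self.prompt_pool)  # (batch, num_prompts)
        idx = dists.argmin(dim=1)                     # discrete prompt identity
        quantized = self.prompt_pool[idx]             # (batch, dim)
        # Straight-through estimator: forward uses the quantized prompt,
        # backward copies gradients onto the continuous query.
        prompt = query + (quantized - query).detach()
        return prompt, idx

# Usage: quantize a batch of queries into discrete prompts.
vq = VQPromptSketch(num_prompts=10, dim=768)
q = torch.randn(4, 768)
p, idx = vq(q)
```

The `detach()` trick is what makes the otherwise non-differentiable argmin selection compatible with end-to-end training, which is the limitation of identity-prediction-based selection that the abstract highlights.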

Requirements

  • python=3.8.18
  • torch=2.0.0+cu118
  • torchvision=0.15.1+cu118
  • timm=0.9.12
  • scikit-learn=1.3.2
  • numpy
  • pyaml
  • pillow
  • opencv-python
  • pandas
  • openpyxl (for writing results to an xlsx file)

Datasets

  • Create a folder data/

Checkpoints

  • Create a folder pretrained/

Training

Run the following commands from the project root directory. The scripts are set up for a single GPU.

sh experiments/cifar-100.sh
sh experiments/imagenet-r_all.sh
sh experiments/cub-200.sh

Results

Results will be saved in a folder named outputs/.

Note on setting

Our method has not been tested in other settings such as domain-incremental continual learning.

Reference Codes

[1] CODA-Prompt

[2] HiDe-Prompt

Citation

If you find this repository useful, please cite the following reference.

@inproceedings{jiao2024vector,
  title={Vector Quantization Prompting for Continual Learning},
  author={Jiao, Li and Lai, Qiuxia and Li, Yu and Xu, Qiang},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}
