This repository is the official implementation of MetaGCD: Learning to Continually Learn in Generalized Category Discovery.
The code was tested with Python 3.6, PyTorch 1.4.0, and CUDA 9.2.
We recommend using a conda environment to set up all required dependencies:
conda env create -f environment.yml
conda activate MetaGCD
If you run into problems with the commands above, you can instead install the dependencies with pip:
pip install -r requirements.txt
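To check that the active environment matches the tested configuration (Python 3.6, PyTorch 1.4.0, CUDA 9.2), a short script along these lines may help. It is a minimal sketch and not part of this repository; the function name `check_environment` is hypothetical. It only reads the versions reported by the interpreter and by PyTorch, and it still runs if PyTorch is not installed.

```python
import importlib.util
import platform


def check_environment():
    """Report the Python, PyTorch, and CUDA versions visible to this interpreter."""
    report = {"python": platform.python_version(), "torch": None,
              "cuda": None, "cuda_available": False}
    # Import torch only if it is installed, so the check works in a bare interpreter.
    if importlib.util.find_spec("torch") is not None:
        import torch
        report["torch"] = torch.__version__
        # torch.version.cuda is None for CPU-only PyTorch builds.
        report["cuda"] = torch.version.cuda
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print("{}: {}".format(key, value))
```

Comparing the printed `torch` and `cuda` values against the versions listed above is a quick way to spot a mismatched installation before starting a long training run.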
We provide a training script for three datasets from the NCD benchmark: CIFAR-10, CIFAR-100, and Tiny-ImageNet. To train the models from the paper, run the following command:
python methods/contrastive_training/contrastive_learning_based_MAML.py --run_mode 'MetaTrain' --dataset_name <dataset>
Before running, set the paths to the datasets, pre-trained models, and desired log directories in config.py.
To evaluate meta-trained models, run:
python methods/contrastive_training/contrastive_learning_based_MAML.py --run_mode 'MetaTest' --dataset_name <dataset>
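The training and evaluation commands differ only in `--run_mode`. The sketch below wraps both in a small Python driver that meta-trains and then evaluates in sequence. It assumes it is launched from the repository root, and the helpers `build_command` and `run_pipeline` are hypothetical, not part of the codebase. It uses only the two flags documented here; the accepted values for `--dataset_name` are defined by the repository and are not assumed.

```python
import subprocess
import sys

# Path to the entry point, relative to the repository root (as in the commands above).
SCRIPT = "methods/contrastive_training/contrastive_learning_based_MAML.py"


def build_command(run_mode, dataset_name):
    """Return the argument list for one MetaTrain or MetaTest run."""
    if run_mode not in ("MetaTrain", "MetaTest"):
        raise ValueError("run_mode must be 'MetaTrain' or 'MetaTest'")
    # Use the current interpreter so the active conda environment is reused.
    return [sys.executable, SCRIPT, "--run_mode", run_mode,
            "--dataset_name", dataset_name]


def run_pipeline(dataset_name):
    """Meta-train, then evaluate; check=True stops if either step fails."""
    for mode in ("MetaTrain", "MetaTest"):
        subprocess.run(build_command(mode, dataset_name), check=True)
```

Calling `run_pipeline("<dataset>")` from the repository root runs the same two commands shown above with the dataset name you would otherwise pass on the command line. Passing an argument list to `subprocess.run` avoids shell quoting, so `MetaTrain` needs no surrounding quotes here.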
If you find this codebase useful in your research, consider citing:
@inproceedings{wu2023metagcd,
title={MetaGCD: Learning to Continually Learn in Generalized Category Discovery},
author={Yanan Wu and Zhixiang Chi and Yang Wang and Songhe Feng},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year={2023}
}