A PyTorch implementation of ACNet, based on the TCSVT 2023 paper "ACNet: Approaching-and-Centralizing Network for Zero-Shot Sketch-Based Image Retrieval".
```shell
conda install pytorch=1.10.0 torchvision cudatoolkit=11.3 -c pytorch
pip install timm
conda install pytorch-metric-learning -c metric-learning -c pytorch
```
The Sketchy Extended and TU-Berlin Extended datasets are used in this repo. You can download them from the official websites, or from Google Drive. The data directory structure is as follows:
```
├── sketchy
│   ├── train
│   │   ├── sketch
│   │   │   ├── airplane
│   │   │   │   ├── n02691156_58-1.jpg
│   │   │   │   └── ...
│   │   │   └── ...
│   │   └── photo
│   │       └── same structure as sketch
│   └── val
│       └── same structure as train
└── tuberlin
    └── same structure as sketchy
```
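Given this layout, a dataset loader can enumerate (class, image) pairs by walking the per-class subdirectories. The snippet below is a minimal sketch of that traversal, mirroring the tree above; the repo's own dataset class may differ:

```python
from pathlib import Path

def list_images(root, split, modality):
    """Collect (class_name, image_path) pairs from root/split/modality/<class>/*.jpg."""
    base = Path(root) / split / modality
    pairs = []
    for class_dir in sorted(p for p in base.iterdir() if p.is_dir()):
        for img in sorted(class_dir.glob("*.jpg")):
            pairs.append((class_dir.name, img))
    return pairs

# Example: sketches for training on Sketchy
# pairs = list_images("/data/sketchy", "train", "sketch")
```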
```shell
python train.py --data_name tuberlin
```
```
optional arguments:
--data_root       Datasets root path [default value is '/data']
--data_name       Dataset name [default value is 'sketchy'] (choices: ['sketchy', 'tuberlin'])
--backbone_type   Backbone type [default value is 'resnet50'] (choices: ['resnet50', 'vgg16'])
--emb_dim         Embedding dim [default value is 512]
--batch_size      Number of images in each mini-batch [default value is 64]
--epochs          Number of epochs to train the model [default value is 10]
--warmup          Number of warm-up epochs for the extractor [default value is 1]
--save_root       Result saved root path [default value is 'result']
```
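The flags above map onto a standard `argparse` setup. The snippet below is an illustrative reconstruction of the documented training options (the actual `train.py` may declare them differently):

```python
import argparse

def build_train_parser():
    # Illustrative reconstruction of the documented training flags; train.py may differ.
    parser = argparse.ArgumentParser(description='Train ACNet')
    parser.add_argument('--data_root', default='/data', type=str)
    parser.add_argument('--data_name', default='sketchy', choices=['sketchy', 'tuberlin'])
    parser.add_argument('--backbone_type', default='resnet50', choices=['resnet50', 'vgg16'])
    parser.add_argument('--emb_dim', default=512, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=10, type=int)
    parser.add_argument('--warmup', default=1, type=int)
    parser.add_argument('--save_root', default='result', type=str)
    return parser

# Mirrors the example invocation: only --data_name overrides its default.
args = build_train_parser().parse_args(['--data_name', 'tuberlin'])
```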
```shell
python test.py --num 8
```
```
optional arguments:
--data_root       Datasets root path [default value is '/data']
--query_name      Query image name [default value is '/data/sketchy/val/sketch/cow/n01887787_591-14.jpg']
--data_base       Queried database [default value is 'result/sketchy_resnet50_512_vectors.pth']
--num             Retrieval number [default value is 5]
--save_root       Result saved root path [default value is 'result']
```
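Retrieval reduces to ranking the database embeddings by similarity to the query embedding and keeping the top `--num` results. This is a minimal cosine-similarity top-k sketch, not the repo's `test.py` (which works with the saved `*_vectors.pth` database):

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def top_k(query, gallery, k=5):
    """Return indices of the k gallery vectors most similar to the query."""
    scores = [(cosine(query, g), i) for i, g in enumerate(gallery)]
    scores.sort(reverse=True)
    return [i for _, i in scores[:k]]
```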
The models are trained on one NVIDIA GTX TITAN (12 GB) GPU. Adam is used to optimize the model, with a learning rate of 1e-5 for the backbone, 1e-3 for the generator, and 1e-4 for the discriminator. All other hyper-parameters are the default values.
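Per-module learning rates like these are typically expressed as Adam parameter groups. The sketch below shows one way such groups could be assembled by top-level module name; the module names (`backbone`, `generator`, `discriminator`) follow the description above, but the repo's exact attribute names may differ:

```python
# Per-module learning rates from the training setup described above.
LRS = {'backbone': 1e-5, 'generator': 1e-3, 'discriminator': 1e-4}

def param_groups(named_params, lrs=LRS):
    """Group parameters by top-level module name and attach its learning rate.

    Produces the list-of-dicts structure torch.optim.Adam accepts, e.g.
    torch.optim.Adam(param_groups(model.named_parameters())).
    """
    groups = {name: {'params': [], 'lr': lr} for name, lr in lrs.items()}
    for name, p in named_params:
        prefix = name.split('.', 1)[0]
        if prefix in groups:
            groups[prefix]['params'].append(p)
    return list(groups.values())
```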
| Backbone | Dim | Sketchy mAP@200 | Sketchy mAP@all | Sketchy P@100 | Sketchy P@200 | TU-Berlin mAP@200 | TU-Berlin mAP@all | TU-Berlin P@100 | TU-Berlin P@200 | Download |
|---|---|---|---|---|---|---|---|---|---|---|
| VGG16 | 64 | 32.6 | 38.0 | 48.7 | 44.7 | 39.8 | 37.1 | 50.6 | 48.0 | MEGA |
| VGG16 | 512 | 38.3 | 42.2 | 53.3 | 49.3 | 47.2 | 43.9 | 58.1 | 55.3 | MEGA |
| VGG16 | 4096 | 40.0 | 43.2 | 54.6 | 50.8 | 51.7 | 47.9 | 62.3 | 59.3 | MEGA |
| ResNet50 | 64 | 43.0 | 46.0 | 56.8 | 52.7 | 47.5 | 44.9 | 57.2 | 54.9 | MEGA |
| ResNet50 | 512 | 51.7 | 55.9 | 64.3 | 60.8 | 57.7 | 57.7 | 65.8 | 64.4 | MEGA |
| ResNet50 | 4096 | 51.1 | 55.7 | 63.8 | 60.0 | 57.3 | 58.6 | 64.6 | 63.5 | MEGA |
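The metrics in the table can be computed from a ranked binary relevance list. This is a minimal sketch of Precision@k and truncated average precision (assuming AP normalized by the number of hits within the truncated list; published mAP@all variants sometimes normalize by the total number of relevant items instead):

```python
def precision_at_k(relevant, k):
    """Fraction of the top-k retrieved items that are relevant (ranked 0/1 list)."""
    return sum(relevant[:k]) / k

def average_precision(relevant, k=None):
    """Mean of precision@i over ranks i where a relevant item appears.

    Pass k to truncate the ranking (mAP@k); k=None uses the full list.
    Note: normalizes by hits found within the (truncated) list.
    """
    ranked = relevant[:k] if k else relevant
    hits, total = 0, 0.0
    for i, rel in enumerate(ranked, 1):
        if rel:
            hits += 1
            total += hits / i
    return total / hits if hits else 0.0
```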
If you find ACNet helpful, please consider citing:
```
@article{ren2023acnet,
  title={ACNet: Approaching-and-Centralizing Network for Zero-Shot Sketch-Based Image Retrieval},
  author={Ren, Hao and Zheng, Ziqiang and Wu, Yang and Lu, Hong and Yang, Yang and Shan, Ying and Yeung, Sai-Kit},
  journal={IEEE Transactions on Circuits and Systems for Video Technology},
  year={2023},
  publisher={IEEE}
}
```