A PyTorch implementation of CGD based on the paper Combination of Multiple Global Descriptors for Image Retrieval.
- PyTorch: `conda install pytorch torchvision cudatoolkit=10.0 -c pytorch`
- thop: `pip install thop`

CARS196, CUB200-2011, Stanford Online Products and In-shop Clothes are used in this repo. Download these datasets yourself and extract them into the `${data_path}` directory, making sure the directory names are `car`, `cub`, `sop` and `isc`. Then run `data_utils.py` to preprocess them.
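
Before running `data_utils.py`, it can help to confirm that the expected directories are in place. The following is a minimal sketch, not part of the repo: the helper name `missing_dataset_dirs` is hypothetical, and it only checks that the four directory names exist, not their contents.

```python
from pathlib import Path

# Directory names taken from the paragraph above.
EXPECTED_DIRS = ("car", "cub", "sop", "isc")


def missing_dataset_dirs(data_path):
    """Return the expected dataset directories that are absent under data_path.

    Hypothetical helper for illustration; it does not inspect directory contents.
    """
    root = Path(data_path)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]
```

Calling `missing_dataset_dirs('/home/data')` returns an empty list when all four directories are present under the default data path.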

To train a model, run `python train.py --feature_dim 512 --gd_config SM`.

Optional arguments:

- `--data_path`: datasets path [default value is '/home/data']
- `--data_name`: dataset name [default value is 'car'] (choices=['car', 'cub', 'sop', 'isc'])
- `--crop_type`: crop data or not, only applies to the car or cub dataset [default value is 'uncropped'] (choices=['uncropped', 'cropped'])
- `--backbone_type`: backbone network type [default value is 'resnet50'] (choices=['resnet50', 'resnext50'])
- `--gd_config`: global descriptors config [default value is 'SG'] (choices=['S', 'M', 'G', 'SM', 'MS', 'SG', 'GS', 'MG', 'GM', 'SMG', 'MSG', 'GSM'])
- `--feature_dim`: feature dim [default value is 1536]
- `--smoothing`: smoothing value for label smoothing [default value is 0.1]
- `--temperature`: temperature scaling used in softmax cross-entropy loss [default value is 0.5]
- `--margin`: margin m for triplet loss [default value is 0.1]
- `--recalls`: selected recall ranks [default value is '1,2,4,8']
- `--batch_size`: train batch size [default value is 128]
- `--num_epochs`: train epoch number [default value is 20]
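
The options listed above can be mirrored with a small `argparse` sketch, which is handy for checking how a command line such as the example resolves against the defaults. This is an illustration built only from the documented names, defaults and choices; the function name `build_train_parser` is hypothetical and the repo's actual `train.py` may define its parser differently.

```python
import argparse

GD_CHOICES = ["S", "M", "G", "SM", "MS", "SG", "GS", "MG", "GM", "SMG", "MSG", "GSM"]


def build_train_parser():
    """Hypothetical parser mirroring the documented train.py options."""
    p = argparse.ArgumentParser(description="Train CGD (illustrative sketch)")
    p.add_argument("--data_path", default="/home/data", type=str, help="datasets path")
    p.add_argument("--data_name", default="car", type=str,
                   choices=["car", "cub", "sop", "isc"], help="dataset name")
    p.add_argument("--crop_type", default="uncropped", type=str,
                   choices=["uncropped", "cropped"], help="crop data or not")
    p.add_argument("--backbone_type", default="resnet50", type=str,
                   choices=["resnet50", "resnext50"], help="backbone network type")
    p.add_argument("--gd_config", default="SG", type=str,
                   choices=GD_CHOICES, help="global descriptors config")
    p.add_argument("--feature_dim", default=1536, type=int, help="feature dim")
    p.add_argument("--smoothing", default=0.1, type=float, help="label smoothing value")
    p.add_argument("--temperature", default=0.5, type=float, help="temperature scaling")
    p.add_argument("--margin", default=0.1, type=float, help="triplet loss margin")
    p.add_argument("--recalls", default="1,2,4,8", type=str, help="selected recall ranks")
    p.add_argument("--batch_size", default=128, type=int, help="train batch size")
    p.add_argument("--num_epochs", default=20, type=int, help="train epoch number")
    return p


# Parse the example command line from this README.
args = build_train_parser().parse_args(["--feature_dim", "512", "--gd_config", "SM"])
# The comma-separated recall ranks become a list of integers.
recalls = [int(k) for k in args.recalls.split(",")]
print(args.feature_dim, args.gd_config, recalls)  # prints: 512 SM [1, 2, 4, 8]
```

Options not given on the command line, such as `--data_name` or `--batch_size`, keep their documented defaults.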

To test a model, run `python test.py --retrieval_num 10`.

Optional arguments:

- `--query_img_name`: query image name [default value is '/home/data/car/uncropped/008055.jpg']
- `--data_base`: queried database [default value is 'car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_data_base.pth']
- `--retrieval_num`: retrieval number [default value is 8]
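
The default `--data_base` value appears to be assembled from the training options in the order they are documented (dataset, crop type, backbone, descriptor config, feature dim, smoothing, temperature, margin, batch size). The sketch below reproduces the default name under that assumption. The function `data_base_name` is hypothetical, and the field order is inferred from the default value alone, so check it against the file that training actually writes.

```python
def data_base_name(data_name="car", crop_type="uncropped", backbone_type="resnet50",
                   gd_config="SG", feature_dim=1536, smoothing=0.1,
                   temperature=0.5, margin=0.1, batch_size=128):
    """Build a database file name from training options.

    Assumption: the field order is inferred from the documented default
    value of --data_base, not taken from the repo's code.
    """
    fields = [data_name, crop_type, backbone_type, gd_config, feature_dim,
              smoothing, temperature, margin, batch_size]
    return "_".join(str(f) for f in fields) + "_data_base.pth"


print(data_base_name())
# prints: car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_data_base.pth
```

Under the same assumption, a model trained with `--feature_dim 512 --gd_config SM` would be queried with `data_base_name(gd_config="SM", feature_dim=512)`.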

The models are trained on one NVIDIA Tesla V100 (32G) GPU for 20 epochs; the learning rate is decayed by a factor of 10 at the 12th and 16th epochs.
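
The step decay described above can be written as a small function. This is a sketch under two assumptions: epochs are counted from 1, and the decayed rate takes effect from the milestone epoch onward (the exact boundary in `train.py` may differ by one epoch). The base learning rate is not stated in this README, so `base_lr` is left as a parameter and `lr_at_epoch` is a hypothetical name.

```python
def lr_at_epoch(epoch, base_lr, milestones=(12, 16), gamma=0.1):
    """Learning rate during `epoch` (1-indexed) under step decay.

    Assumption: the rate is multiplied by `gamma` once for every
    milestone that is less than or equal to the current epoch.
    """
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** num_decays


# Print the multiplier applied to the base rate for a few epochs.
for epoch in (1, 11, 12, 15, 16, 20):
    print(epoch, lr_at_epoch(epoch, base_lr=1.0))
```

With 20 epochs this gives three phases: the base rate for epochs 1 to 11, one tenth of it for epochs 12 to 15, and one hundredth for epochs 16 to 20.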

Model parameters and FLOPs (each cell is Params / FLOPs):

Backbone | CARS196 | CUB200 | SOP | In-shop |
---|---|---|---|---|
ResNet50 | 26.86M / 10.64G | 26.86M / 10.64G | 49.85M / 10.69G | 34.85M / 10.66G |
ResNeXt50 | 26.33M / 10.84G | 26.33M / 10.84G | 49.32M / 10.89G | 34.32M / 10.86G |

CARS196 (each cell is uncropped / cropped):

Backbone | R@1 | R@2 | R@4 | R@8 | Download Link |
---|---|---|---|---|---|
ResNet50(SG) | 86.4% / 92.4% | 92.1% / 96.1% | 95.6% / 97.8% | 97.5% / 98.7% | r3sn / sf5s |
ResNeXt50(SG) | 86.4% / 91.7% | 92.0% / 95.4% | 95.4% / 97.3% | 97.6% / 98.6% | dsdx / fh72 |

CUB200 (each cell is uncropped / cropped):

Backbone | R@1 | R@2 | R@4 | R@8 | Download Link |
---|---|---|---|---|---|
ResNet50(MG) | 66.0% / 73.9% | 76.4% / 83.1% | 84.8% / 89.6% | 90.7% / 94.0% | 2cfi / pi4q |
ResNeXt50(MG) | 66.1% / 73.7% | 76.3% / 82.6% | 84.0% / 89.0% | 90.1% / 93.3% | nm9h / 6mkf |

SOP:

Backbone | R@1 | R@10 | R@100 | R@1000 | Download Link |
---|---|---|---|---|---|
ResNet50(SG) | 79.3% | 90.6% | 95.8% | 98.6% | qgsn |
ResNeXt50(SG) | 71.0% | 85.3% | 93.5% | 97.9% | uexd |

In-shop:

Backbone | R@1 | R@10 | R@20 | R@30 | R@40 | R@50 | Download Link |
---|---|---|---|---|---|---|---|
ResNet50(GS) | 83.6% | 95.7% | 97.1% | 97.7% | 98.1% | 98.4% | 8jmp |
ResNeXt50(GS) | 85.0% | 96.1% | 97.3% | 97.9% | 98.2% | 98.4% | wdq5 |
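
The R@K columns in the tables above report recall at rank K. As a rough guide to reading them, the sketch below computes the metric under the common image-retrieval definition: a query counts as a hit if at least one of its K nearest gallery items shares its label. This is an assumption about the definition, not the repo's evaluation code, and the function name and toy labels are made up for illustration.

```python
def recall_at_k(query_labels, ranked_gallery_labels, ks=(1, 2, 4, 8)):
    """Fraction of queries with at least one same-label item in the top K.

    ranked_gallery_labels[i] holds the gallery labels for query i, ordered
    from most to least similar, with the query itself excluded.
    """
    results = {}
    for k in ks:
        hits = sum(
            1
            for label, ranked in zip(query_labels, ranked_gallery_labels)
            if label in ranked[:k]
        )
        results[k] = hits / len(query_labels)
    return results


# Toy example with three queries (made-up labels, not real results).
queries = ["a", "b", "c"]
rankings = [
    ["a", "b", "c"],  # correct label at rank 1
    ["c", "b", "a"],  # correct label at rank 2
    ["a", "b", "b"],  # correct label never retrieved
]
toy_recalls = recall_at_k(queries, rankings, ks=(1, 2))
print(toy_recalls)
```

In the toy example one of three queries is a hit at K=1 and two of three at K=2, which shows why recall can only stay the same or increase as K grows.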