对于细粒度分类问题,一般的网络不能取得准确率较高的结果,而论文《 See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification》提出基于弱监督的数据增强网络,作者提出了基于注意力的双线性池化(BAP)、注意力正则化(AP)、注意力引导数据增强(Drop和Crop)、最后预测阶段进行目标定位与图像精修(Refinement),本Repo为基于Paddle2.2框架的复现。
参考repo: https://github.com/wvinzh/WS_DAN_PyTorch
在此非常感谢wvinzh
等人贡献的WS_DAN_PyTorch,提高了本repo复现论文的效率。
AI Studio体验教程: 运行一下 (已挂载相关数据集,省去修改路径步骤,模型测试、训练直接运行!)
Github复现地址: 点击查看
论文中采用的数据集均为细粒度分类问题的典型代表,包括鸟、飞机、汽车、狗等类别,相关数据集的下载及复现精度如下:(点击数据集链接可下载对应数据集,推荐直接到百度网盘 提取码:1234下载,这样不用修改子文件夹名称)
注:其中CUB-200-2011鸟类数据集需将epoch提升至100才能达到原论文精度(原论文80epoch),其余数据集均可在80epoch内达到相应精度
Dataset | Object | Category | Training | Testing | ACC(复现) | ACC(原论文) |
---|---|---|---|---|---|---|
CUB-200-2011 | Bird | 200 | 5994 | 5794 | 89.40/89.23(80epoch) | 89.4 |
fgvc-aircraft | Aircraft | 100 | 6667 | 3333 | 94.03 | 93.0 |
Stanford-Cars | Car | 196 | 8144 | 8041 | 94.88 | 94.5 |
Stanford-Dogs | Dogs | 120 | 12000 | 8580 | (未要求) | 92.2 |
PaddlePaddle == 2.2.0
如下为数据集目录,Fine-grained为总的数据集文件夹,以下列出为模型训练、预测时需要用到的文件,请下载并按照如下名称命名相关文件夹,若从百度网盘给出的链接下载,则可省去修改文夹名
Fine-grained
├── CUB_200_2011
├── images
├── images.txt
├── image_class_labels.txt
├── train_test_split.txt
├── Car
├── cars_test
├── cars_train
├── cars_test_annos_withlabels.mat
├── devkit
├── cars_train_annos.mat
├── fgvc-aircraft-2013b
├── data
├── variants.txt
├── images_variant_trainval.txt
├── images_variant_test.txt
若您想从头训练,需要准备Inceptionv3预训练模型权重,该模型用于提取特征图(Feature Map, FM)和注意力图(Attention Map, AM),由于Github无法上传100MB以上文件,该权重参数141MB,所以您需要手动下载Inceptionv3提取码:1234预训练模型参数文件,并保存到models文件夹下,然后即可开始训练。
WS-DAN-Paddle-Victory8858
├── README.md # 用户指南
├── datasets # 各种数据集定义读取文件夹
├── CUBTINY # 一小部分鸟类数据集
├── *.jpg # 鸟类图片(共5张)
├── *.txt # 训练、预测标签
├── __init__.py # 读取数据集函数
├── aircraft_dataset.py # 飞机类数据集定义
├── bird_dataset.py # 鸟类数据集定义
├── bird_tiny_dataset.py # 一小部分鸟类数据集定义(用于TIPC)
├── car_dataset.py # 车类数据集定义
├── models # 模型相关文件
├── bap.py # BAP模型
├── inception.py # Inceptionv3模型
├── wsdan.py # WS-DAN模型
├── InceptionV3_pretrained.pdparams # Inceptionv3模型权重(需要您下载,见3.3中链接)
├── test_tipc # TICP
├── # 具体见TICP部分文档
├── FGVC # 模型参数保存与训练日志
├── aircraft # 飞机类模型参数以及训练日志
├── *.pdparams # 模型网络权重
├── *.log # 训练日志
├── brid # 鸟类模型参数以及训练日志
├── *.pdparams # 模型网络权重
├── *.log # 训练日志
├── car # 车类模型参数以及训练日志
├── *.pdparams # 模型网络权重
├── *.log # 训练日志
├── dataset_path_config.py # 数据集路径配置文件(您需要修改)
├── train.py # 模型训练
├── train.sh # 模型训练启动脚本
├── val.py # 模型测试
├── val.sh # 模型测试启动脚本
├── predicted.py # 单张图片预测
├── export_model.py # 模型动转静
├── infer.py # 利用静态模型进行推理
├── utils.py # 工具链
└── imgs # Markdown 图片资源
在开始训练前,假如您已经按上述操作准备好了相关数据集,并按照3.2中的文件名命名,那么最后一步就是修改dataset_path_config.py
文件中的数据集路径,您需要修改的内容如下:
bird_dataset_path = "E:/dataset/Fine-grained/CUB_200_2011" # 修改为您的路径
car_dataset_path = "E:/dataset/Fine-grained/Car" # 修改为您的路径
aircraft_dataset_path = "E:/dataset/Fine-grained/fgvc-aircraft-2013b/data" # 修改为您的路径
修改好后,马上即可开始训练、测试。
共有3种数据集需要训练,每个数据集都需要训练一个模型,在训练开始前,您可修改train.sh中的dataset
变量来指定想要训练的模型,如下所示:
bash train.sh
# 或
python train.py --dataset car --epochs 80 --batch_size 12 --num_workers 0
您只需运行val.sh文件即可,修改dataset
变量即可指定测试何种数据集的精度
bash val.sh
# 或
python val.py --dataset car --batch-size 6 --num-workers 0
如下所示:
您需运行predict.py文件,如想更换预测的数据集,运行predict.sh文件修改dataset
变量即可
bash predict.sh
# 或
python predict.py --dataset aircraft # Options: bird, car, aircraft
预测结果如下所示:
export_model.py为模型导出文件,可以将训练好的*.pdparams
模型进行动转静,运行方式如下:
其中model
的值可为bird、car、aircraft、bird_tiny,save_dir
为转换模型的保存路径
python export_model.py --model "bird" --save-dir "output"
infer.py为模型推理文件,输入一张图片利用静态模型进行推理,得到图片类别和其对应概率,本文件在TIPC中会用到,若输入CUBTINY数据集中的3.jpg
,其推理结果如下图所示,class_id
代表类别,prob
代表对应概率值
您可通过如下方式运行:
python infer.py --img-path datasets/CUBTINY/3.jpg # 可更换为CUBTINY中存在的图片(1-5.jpg)
TIPC命令如下,执行完后根目录内会生成output,log文件夹,其中output为保存的模型,log为运行结果的日志,详见TIPC文档:
bash test_tipc/test_train_inference_python.sh test_tipc/configs/WS-DAN/train_infer_python.txt lite_train_lite_infer