FireClassification is a deep learning Framework written in Python and used for Image Classification task, running on top of the machine learning platform Pytorch.
Read the source code as documentation.
首先git clone本项目
- 下载fashion mnist数据集的四个压缩包放到./data目录下,运行
python scripts/make_fashionmnist.py
自动提取图片并划分类别、验证集 - 执行python train.py 训练
- 执行python evaluate.py 测试(在config设置训练好的模型路径)
- 迁移学习,下载对应模型的预训练模型,把路径填入config.py中
- 调整不同的模型、尺寸、优化器等等
依次修改fire/model.py相应代码即可。
- 文件夹形式
- csv标签形式
- 其它自定义形式需手动修改代码
- Resnet系列,Densenet系列,VGGnet系列等所有pretrained-models.pytorch支持的网络
- Mobilenetv2,Mbilenetv3,ShuffleNetV2
- EfficientNet
- Swin Transformer
- ConvNeXt
- TIMM库所有模型
- Adam
- SGD
- AdaBelief
- AdamW
- ReduceLROnPlateau
- StepLR
- MultiStepLR
- SGDR
- 交叉熵
- Focalloss
- Metric(acc, F1)
- 训练日志保存
- 交叉验证
- 梯度裁剪
- earlystop
- weightdecay
- 按文件夹设置分类标签、读取csv标签
- 冻结/解冻 除最后的全连接层的特征层
- labelsmooth
- 2023.9 [v1.1] 优化代码,删掉一些不用的功能,替换一些依赖库为自己实现,修复bug简化代码,修改存储路径
- 2022.7 [v1.0] (根据这半年打比赛经验,增加一些东西,删除一些几乎不用的东西。) 增加convnext、swin transformer、半精度训练,删除mobileformer,删除日志、tensorboard(习惯用文档记录),优化readme
- 2021.8 [v0.9] 增加micronet和测试结果,增加rk3399测速
- 2021.8 [v0.8] 增加mobileformer,加入fashion mnist数据集使用demo,方便测试各种模型,同时加入部分网络的训练结果
- 完善Readme
- 增加使用文档
- 彻底分离用户自定义部分的代码