diff --git a/configs/det/det_repsvtr_db.yml b/configs/det/det_repsvtr_db.yml
new file mode 100644
index 0000000000..8c4768e714
--- /dev/null
+++ b/configs/det/det_repsvtr_db.yml
@@ -0,0 +1,169 @@
+Global:
+ debug: false
+ use_gpu: true
+ epoch_num: &epoch_num 500
+ log_smooth_window: 20
+ print_batch_step: 100
+ save_model_dir: ./output/det_repsvtr_db
+ save_epoch_step: 10
+ eval_batch_step:
+ - 0
+ - 1000
+ cal_metric_during_train: false
+ checkpoints:
+ pretrained_model:
+ save_inference_dir: null
+ use_visualdl: false
+ infer_img: doc/imgs_en/img_10.jpg
+ save_res_path: ./checkpoints/det_db/predicts_db.txt
+ distributed: true
+
+Architecture:
+ model_type: det
+ algorithm: DB
+ Transform: null
+ Backbone:
+ name: RepSVTR_det
+ Neck:
+ name: RSEFPN
+ out_channels: 96
+ shortcut: True
+ Head:
+ name: DBHead
+ k: 50
+
+Loss:
+ name: DBLoss
+ balance_loss: true
+ main_loss_type: DiceLoss
+ alpha: 5
+ beta: 10
+ ohem_ratio: 3
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: Cosine
+ learning_rate: 0.001 #(8*8c)
+ warmup_epoch: 2
+ regularizer:
+ name: L2
+ factor: 5.0e-05
+
+PostProcess:
+ name: DBPostProcess
+ thresh: 0.3
+ box_thresh: 0.6
+ max_candidates: 1000
+ unclip_ratio: 1.5
+
+Metric:
+ name: DetMetric
+ main_indicator: hmean
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/icdar2015/text_localization/
+ label_file_list:
+ - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+ ratio_list: [1.0]
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - DetLabelEncode: null
+ - CopyPaste: null
+ - IaaAugment:
+ augmenter_args:
+ - type: Fliplr
+ args:
+ p: 0.5
+ - type: Affine
+ args:
+ rotate:
+ - -10
+ - 10
+ - type: Resize
+ args:
+ size:
+ - 0.5
+ - 3
+ - EastRandomCropData:
+ size:
+ - 640
+ - 640
+ max_tries: 50
+ keep_ratio: true
+ - MakeBorderMap:
+ shrink_ratio: 0.4
+ thresh_min: 0.3
+ thresh_max: 0.7
+ total_epoch: *epoch_num
+ - MakeShrinkMap:
+ shrink_ratio: 0.4
+ min_text_size: 8
+ total_epoch: *epoch_num
+ - NormalizeImage:
+ scale: 1./255.
+ mean:
+ - 0.485
+ - 0.456
+ - 0.406
+ std:
+ - 0.229
+ - 0.224
+ - 0.225
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys:
+ - image
+ - threshold_map
+ - threshold_mask
+ - shrink_map
+ - shrink_mask
+ loader:
+ shuffle: true
+ drop_last: false
+ batch_size_per_card: 8
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/icdar2015/text_localization/
+ label_file_list:
+ - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - DetLabelEncode: null
+ - DetResizeForTest:
+ - NormalizeImage:
+ scale: 1./255.
+ mean:
+ - 0.485
+ - 0.456
+ - 0.406
+ std:
+ - 0.229
+ - 0.224
+ - 0.225
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys:
+ - image
+ - shape
+ - polys
+ - ignore_tags
+ loader:
+ shuffle: false
+ drop_last: false
+ batch_size_per_card: 1
+ num_workers: 2
+profiler_options: null
diff --git a/configs/rec/SVTRv2/rec_repsvtr_gtc.yml b/configs/rec/SVTRv2/rec_repsvtr_gtc.yml
new file mode 100644
index 0000000000..6d1340ee6f
--- /dev/null
+++ b/configs/rec/SVTRv2/rec_repsvtr_gtc.yml
@@ -0,0 +1,134 @@
+Global:
+ debug: false
+ use_gpu: true
+ epoch_num: 200
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec_repsvtr_gtc
+ save_epoch_step: 10
+ eval_batch_step: [0, 1000]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: false
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+ max_text_length: &max_text_length 25
+ infer_mode: false
+ use_space_char: true
+ distributed: true
+ save_res_path: ./output/rec/predicts_repsvtr.txt
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ epsilon: 1.e-8
+ weight_decay: 0.025
+ no_weight_decay_name: norm
+ one_dim_param_no_weight_decay: True
+ lr:
+ name: Cosine
+ learning_rate: 0.001 # 8gpus 192bs
+ warmup_epoch: 5
+
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR_HGNet
+ Transform:
+ Backbone:
+ name: RepSVTR
+ Head:
+ name: MultiHead
+ head_list:
+ - CTCHead:
+ Neck:
+ name: svtr
+ dims: 256
+ depth: 2
+ hidden_dims: 256
+ kernel_size: [1, 3]
+ use_guide: True
+ Head:
+ fc_decay: 0.00001
+ - NRTRHead:
+ nrtr_dim: 384
+ max_text_length: *max_text_length
+ num_decoder_layers: 2
+
+Loss:
+ name: MultiLoss
+ loss_config_list:
+ - CTCLoss:
+ - NRTRLoss:
+
+PostProcess:
+ name: CTCLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+
+Train:
+ dataset:
+ name: MultiScaleDataSet
+ ds_width: false
+ data_dir: ./train_data/
+ ext_op_transform_idx: 1
+ label_file_list:
+ - ./train_data/train_list.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - RecAug:
+ - MultiLabelEncode:
+ gtc_encode: NRTRLabelEncode
+ - KeepKeys:
+ keep_keys:
+ - image
+ - label_ctc
+ - label_gtc
+ - length
+ - valid_ratio
+ sampler:
+ name: MultiScaleSampler
+ scales: [[320, 32], [320, 48], [320, 64]]
+ first_bs: &bs 192
+ fix_bs: false
+ divided_factor: [8, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: true
+ batch_size_per_card: *bs
+ drop_last: true
+ num_workers: 8
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data
+ label_file_list:
+ - ./train_data/val_list.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - MultiLabelEncode:
+ gtc_encode: NRTRLabelEncode
+ - RecResizeImg:
+ image_shape: [3, 48, 320]
+ - KeepKeys:
+ keep_keys:
+ - image
+ - label_ctc
+ - label_gtc
+ - length
+ - valid_ratio
+ loader:
+ shuffle: false
+ drop_last: false
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/SVTRv2/rec_svtrv2_gtc.yml b/configs/rec/SVTRv2/rec_svtrv2_gtc.yml
new file mode 100644
index 0000000000..d2ab95ac38
--- /dev/null
+++ b/configs/rec/SVTRv2/rec_svtrv2_gtc.yml
@@ -0,0 +1,145 @@
+Global:
+ debug: false
+ use_gpu: true
+ epoch_num: 200
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec_svtrv2_gtc
+ save_epoch_step: 10
+ eval_batch_step: [0, 1000]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: false
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+ max_text_length: &max_text_length 25
+ infer_mode: false
+ use_space_char: true
+ distributed: true
+ save_res_path: ./output/rec/predicts_svrtv2.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ epsilon: 1.e-8
+ weight_decay: 0.05
+ no_weight_decay_name: norm
+ one_dim_param_no_weight_decay: True
+ lr:
+ name: Cosine
+ learning_rate: 0.001 # 8gpus 192bs
+ warmup_epoch: 5
+
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR_HGNet
+ Transform:
+ Backbone:
+ name: SVTRv2
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','Global','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[2, 1], [2, 1], [-1, -1]]
+ last_stage: False
+ use_pool: True
+ Head:
+ name: MultiHead
+ head_list:
+ - CTCHead:
+ Neck:
+ name: svtr
+ dims: 256
+ depth: 2
+ hidden_dims: 256
+ kernel_size: [1, 3]
+ use_guide: True
+ Head:
+ fc_decay: 0.00001
+ - NRTRHead:
+ nrtr_dim: 384
+ max_text_length: *max_text_length
+ num_decoder_layers: 2
+
+Loss:
+ name: MultiLoss
+ loss_config_list:
+ - CTCLoss:
+ - NRTRLoss:
+
+PostProcess:
+ name: CTCLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+
+
+Train:
+ dataset:
+ name: MultiScaleDataSet
+ ds_width: false
+ data_dir: ./train_data/
+ ext_op_transform_idx: 1
+ label_file_list:
+ - ./train_data/train_list.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - RecAug:
+ - MultiLabelEncode:
+ gtc_encode: NRTRLabelEncode
+ - KeepKeys:
+ keep_keys:
+ - image
+ - label_ctc
+ - label_gtc
+ - length
+ - valid_ratio
+ sampler:
+ name: MultiScaleSampler
+ scales: [[320, 32], [320, 48], [320, 64]]
+ first_bs: &bs 192
+ fix_bs: false
+ divided_factor: [8, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: true
+ batch_size_per_card: *bs
+ drop_last: true
+ num_workers: 8
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data
+ label_file_list:
+ - ./train_data/val_list.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - MultiLabelEncode:
+ gtc_encode: NRTRLabelEncode
+ - RecResizeImg:
+ image_shape: [3, 48, 320]
+ - KeepKeys:
+ keep_keys:
+ - image
+ - label_ctc
+ - label_gtc
+ - length
+ - valid_ratio
+ loader:
+ shuffle: false
+ drop_last: false
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/configs/rec/SVTRv2/rec_svtrv2_gtc_distill.yml b/configs/rec/SVTRv2/rec_svtrv2_gtc_distill.yml
new file mode 100644
index 0000000000..15d781fc22
--- /dev/null
+++ b/configs/rec/SVTRv2/rec_svtrv2_gtc_distill.yml
@@ -0,0 +1,208 @@
+Global:
+ debug: false
+ use_gpu: true
+ epoch_num: 100
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec_svtrv2_gtc_distill_lr00002/
+ save_epoch_step: 5
+ eval_batch_step:
+ - 0
+ - 1000
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: false
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+ max_text_length: &max_text_length 25
+ infer_mode: false
+ use_space_char: true
+ distributed: true
+ save_res_path: ./output/rec/predicts_svtrv2_gtc_distill.txt
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.99
+ epsilon: 1.e-8
+ weight_decay: 0.05
+ no_weight_decay_name: norm pos_embed patch_embed downsample
+ one_dim_param_no_weight_decay: True
+ lr:
+ name: Cosine
+ learning_rate: 0.0002 # 8gpus 192bs
+ warmup_epoch: 5
+Architecture:
+ model_type: rec
+ name: DistillationModel
+ algorithm: Distillation
+ Models:
+ Teacher:
+ pretrained: ./output/rec_svtrv2_gtc/best_accuracy
+ freeze_params: true
+ return_all_feats: true
+ model_type: rec
+ algorithm: SVTR_LCNet
+ Transform: null
+ Backbone:
+ name: SVTRv2
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','Global','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[2, 1], [2, 1], [-1, -1]]
+ last_stage: False
+ use_pool: True
+ Head:
+ name: MultiHead
+ head_list:
+ - CTCHead:
+ Neck:
+ name: svtr
+ dims: 256
+ depth: 2
+ hidden_dims: 256
+ kernel_size: [1, 3]
+ use_guide: True
+ Head:
+ fc_decay: 0.00001
+ - NRTRHead:
+ nrtr_dim: 384
+ num_decoder_layers: 2
+ max_text_length: *max_text_length
+ Student:
+ pretrained: ./output/rec_repsvtr_gtc/best_accuracy
+ freeze_params: false
+ return_all_feats: true
+ model_type: rec
+ algorithm: SVTR_LCNet
+ Transform: null
+ Backbone:
+ name: repvit_svtr
+ Head:
+ name: MultiHead
+ head_list:
+ - CTCHead:
+ Neck:
+ name: svtr
+ dims: 256
+ depth: 2
+ hidden_dims: 256
+ kernel_size: [1, 3]
+ use_guide: True
+ Head:
+ fc_decay: 0.00001
+ - NRTRHead:
+ nrtr_dim: 384
+ num_decoder_layers: 2
+ max_text_length: *max_text_length
+Loss:
+ name: CombinedLoss
+ loss_config_list:
+ - DistillationDKDLoss:
+ weight: 0.1
+ model_name_pairs:
+ - - Student
+ - Teacher
+ key: head_out
+ multi_head: true
+ alpha: 1.0
+ beta: 2.0
+ dis_head: gtc
+ name: dkd
+ - DistillationCTCLoss:
+ weight: 1.0
+ model_name_list:
+ - Student
+ key: head_out
+ multi_head: true
+ - DistillationNRTRLoss:
+ weight: 1.0
+ smoothing: false
+ model_name_list:
+ - Student
+ key: head_out
+ multi_head: true
+ - DistillCTCLogits:
+ weight: 1.0
+ reduction: mean
+ model_name_pairs:
+ - - Student
+ - Teacher
+ key: head_out
+PostProcess:
+ name: DistillationCTCLabelDecode
+ model_name:
+ - Student
+ key: head_out
+ multi_head: true
+Metric:
+ name: DistillationMetric
+ base_metric_name: RecMetric
+ main_indicator: acc
+ key: Student
+
+
+Train:
+ dataset:
+ name: MultiScaleDataSet
+ ds_width: false
+ data_dir: ./train_data/
+ ext_op_transform_idx: 1
+ label_file_list:
+ - ./train_data/train_list.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - RecAug:
+ - MultiLabelEncode:
+ gtc_encode: NRTRLabelEncode
+ - KeepKeys:
+ keep_keys:
+ - image
+ - label_ctc
+ - label_gtc
+ - length
+ - valid_ratio
+ sampler:
+ name: MultiScaleSampler
+ scales: [[320, 32], [320, 48], [320, 64]]
+ first_bs: &bs 192
+ fix_bs: false
+ divided_factor: [8, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: true
+ batch_size_per_card: *bs
+ drop_last: true
+ num_workers: 8
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data
+ label_file_list:
+ - ./train_data/val_list.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - MultiLabelEncode:
+ gtc_encode: NRTRLabelEncode
+ - RecResizeImg:
+ image_shape: [3, 48, 320]
+ - KeepKeys:
+ keep_keys:
+ - image
+ - label_ctc
+ - label_gtc
+ - length
+ - valid_ratio
+ loader:
+ shuffle: false
+ drop_last: false
+ batch_size_per_card: 128
+ num_workers: 4
diff --git a/doc/doc_ch/algorithm_rec_svtr.md b/doc/doc_ch/algorithm_rec_svtr.md
index 42a1a9a415..34881c1146 100644
--- a/doc/doc_ch/algorithm_rec_svtr.md
+++ b/doc/doc_ch/algorithm_rec_svtr.md
@@ -18,7 +18,7 @@
论文信息:
> [SVTR: Scene Text Recognition with a Single Visual Model](https://arxiv.org/abs/2205.00159)
-> Yongkun Du and Zhineng Chen and Caiyan Jia Xiaoting Yin and Tianlun Zheng and Chenxia Li and Yuning Du and Yu-Gang Jiang
+> Yongkun Du and Zhineng Chen and Caiyan Jia and Xiaoting Yin and Tianlun Zheng and Chenxia Li and Yuning Du and Yu-Gang Jiang
> IJCAI, 2022
场景文本识别旨在将自然图像中的文本转录为数字字符序列,从而传达对场景理解至关重要的高级语义。这项任务由于文本变形、字体、遮挡、杂乱背景等方面的变化具有一定的挑战性。先前的方法为提高识别精度做出了许多工作。然而文本识别器除了准确度外,还因为实际需求需要考虑推理速度等因素。
@@ -102,7 +102,7 @@ python3 tools/infer_rec.py -c ./rec_svtr_tiny_none_ctc_en_train/rec_svtr_tiny_6l
### 4.1 Python推理
-首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型和配置文件下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
+首先将训练得到best模型,转换成inference model。下面以`SVTR-T`在英文数据集训练的模型为例([模型和配置文件下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
diff --git a/doc/doc_ch/algorithm_rec_svtrv2.md b/doc/doc_ch/algorithm_rec_svtrv2.md
new file mode 100644
index 0000000000..a508b4f02c
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_svtrv2.md
@@ -0,0 +1,143 @@
+# 场景文本识别算法-SVTRv2
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+ - [3.1 训练](#3-1)
+ - [3.2 评估](#3-2)
+ - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+ - [4.1 Python推理](#4-1)
+ - [4.2 C++推理](#4-2)
+ - [4.3 Serving服务化部署](#4-3)
+ - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+### SVTRv2算法简介
+
+
+[PaddleOCR 算法模型挑战赛 - 赛题一:OCR 端到端识别任务](https://aistudio.baidu.com/competition/detail/1131/0/introduction)排行榜第一算法。主要思路:1、检测和识别模型的Backbone升级为RepSVTR;2、识别教师模型升级为SVTRv2,可识别长文本。
+
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+
+### 3.1 模型训练
+
+
+训练命令:
+```shell
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml
+
+#多卡训练,通过--gpus参数指定卡号
+# Rec 学生模型
+python -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml
+# Rec 教师模型
+python -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/SVTRv2/rec_svtrv2_gtc.yml
+# Rec 蒸馏训练
+python -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/SVTRv2/rec_svtrv2_gtc_distill.yml
+```
+
+
+### 3.2 评估
+
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml -o Global.pretrained_model=output/rec_repsvtr_gtc/best_accuracy
+```
+
+
+### 3.3 预测
+
+使用如下命令进行单张图片预测:
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/infer_rec.py -c tools/eval.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml -o Global.pretrained_model=output/rec_repsvtr_gtc/best_accuracy Global.infer_img='./doc/imgs_words_en/word_10.png'
+# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
+```
+
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+首先将训练得到best模型,转换成inference model,以RepSVTR为例,可以使用如下命令进行转换:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/export_model.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml -o Global.pretrained_model=output/rec_repsvtr_gtc/best_accuracy Global.save_inference_dir=./inference/rec_repsvtr_infer
+```
+
+**注意:**
+- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否为所正确的字典文件。
+
+转换成功后,在目录下有三个文件:
+```
+./inference/rec_repsvtr_infer/
+ ├── inference.pdiparams # 识别inference模型的参数文件
+ ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
+ └── inference.pdmodel # 识别inference模型的program文件
+```
+
+
+执行如下命令进行模型推理:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_repsvtr_infer/'
+# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
+```
+data:image/s3,"s3://crabby-images/cc608/cc6087ba20b9d140accfcb31aaebc6d98e9c9949" alt=""
+
+执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
+结果如下:
+```shell
+Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999998807907104)
+```
+
+**注意**:
+
+- 如果您调整了训练时的输入分辨率,需要通过参数`rec_image_shape`设置为您需要的识别图像形状。
+- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
+- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中SVTR的预处理为您的预处理方法。
+
+
+### 4.2 C++推理部署
+
+由于C++预处理后处理还未支持SVTRv2
+
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+
+### 4.4 更多推理部署
+
+暂不支持
+
+
+## 5. FAQ
+
+## 引用
+
+```bibtex
+@article{Du2022SVTR,
+ title = {SVTR: Scene Text Recognition with a Single Visual Model},
+ author = {Du, Yongkun and Chen, Zhineng and Jia, Caiyan and Yin, Xiaoting and Zheng, Tianlun and Li, Chenxia and Du, Yuning and Jiang, Yu-Gang},
+ booktitle = {IJCAI},
+ year = {2022},
+ url = {https://arxiv.org/abs/2205.00159}
+}
+```
diff --git a/ppocr/losses/rec_multi_loss.py b/ppocr/losses/rec_multi_loss.py
index c19febe535..74be385651 100644
--- a/ppocr/losses/rec_multi_loss.py
+++ b/ppocr/losses/rec_multi_loss.py
@@ -55,7 +55,7 @@ def forward(self, predicts, batch):
)
elif name == "NRTRLoss":
loss = (
- loss_func(predicts["nrtr"], batch[:1] + batch[2:])["loss"]
+ loss_func(predicts["gtc"], batch[:1] + batch[2:])["loss"]
* self.weight_2
)
else:
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index ce80afd109..0b64992b20 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -25,6 +25,7 @@ def build_backbone(config, model_type):
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
from .rec_vit import ViT
+ from .rec_repvit import RepSVTR_det
support_dict = [
"MobileNetV3",
@@ -34,6 +35,7 @@ def build_backbone(config, model_type):
"PPLCNet",
"PPLCNetV3",
"PPHGNet_small",
+ "RepSVTR_det",
]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
@@ -59,6 +61,8 @@ def build_backbone(config, model_type):
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
from .rec_vit_parseq import ViTParseQ
+ from .rec_repvit import RepSVTR
+ from .rec_svtrv2 import SVTRv2
support_dict = [
"MobileNetV1Enhance",
@@ -81,6 +85,8 @@ def build_backbone(config, model_type):
"PPHGNet_small",
"ViTParseQ",
"ViT",
+ "RepSVTR",
+ "SVTRv2",
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_repvit.py b/ppocr/modeling/backbones/rec_repvit.py
new file mode 100644
index 0000000000..e983569c44
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_repvit.py
@@ -0,0 +1,363 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This code is refer from:
+https://github.com/THU-MIG/RepViT
+"""
+
+import paddle.nn as nn
+import paddle
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+
+trunc_normal_ = TruncatedNormal(std=0.02)
+normal_ = Normal
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
+
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+# from timm.models.layers import SqueezeExcite
+
+
+def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
+ min_value = min_value or divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < round_limit * v:
+ new_v += divisor
+ return new_v
+
+
+class SEModule(nn.Layer):
+ """SE Module as defined in original SE-Nets with a few additions
+ Additions include:
+ * divisor can be specified to keep channels % div == 0 (default: 8)
+ * reduction channels can be specified directly by arg (if rd_channels is set)
+ * reduction channels can be specified by float rd_ratio (default: 1/16)
+ * global max pooling can be added to the squeeze aggregation
+ * customizable activation, normalization, and gate layer
+ """
+
+ def __init__(
+ self,
+ channels,
+ rd_ratio=1.0 / 16,
+ rd_channels=None,
+ rd_divisor=8,
+ act_layer=nn.ReLU,
+ ):
+ super(SEModule, self).__init__()
+ if not rd_channels:
+ rd_channels = make_divisible(
+ channels * rd_ratio, rd_divisor, round_limit=0.0
+ )
+ self.fc1 = nn.Conv2D(channels, rd_channels, kernel_size=1, bias_attr=True)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2D(rd_channels, channels, kernel_size=1, bias_attr=True)
+
+ def forward(self, x):
+ x_se = x.mean((2, 3), keepdim=True)
+ x_se = self.fc1(x_se)
+ x_se = self.act(x_se)
+ x_se = self.fc2(x_se)
+ return x * nn.functional.sigmoid(x_se)
+
+
+class Conv2D_BN(nn.Sequential):
+ def __init__(
+ self,
+ a,
+ b,
+ ks=1,
+ stride=1,
+ pad=0,
+ dilation=1,
+ groups=1,
+ bn_weight_init=1,
+ resolution=-10000,
+ ):
+ super().__init__()
+ self.add_sublayer(
+ "c", nn.Conv2D(a, b, ks, stride, pad, dilation, groups, bias_attr=False)
+ )
+ self.add_sublayer("bn", nn.BatchNorm2D(b))
+ if bn_weight_init == 1:
+ ones_(self.bn.weight)
+ else:
+ zeros_(self.bn.weight)
+ zeros_(self.bn.bias)
+
+ @paddle.no_grad()
+ def fuse(self):
+ c, bn = self.c, self.bn
+ w = bn.weight / (bn._variance + bn._epsilon) ** 0.5
+ w = c.weight * w[:, None, None, None]
+ b = bn.bias - bn._mean * bn.weight / (bn._variance + bn._epsilon) ** 0.5
+ m = nn.Conv2D(
+ w.shape[1] * self.c._groups,
+ w.shape[0],
+ w.shape[2:],
+ stride=self.c._stride,
+ padding=self.c._padding,
+ dilation=self.c._dilation,
+ groups=self.c._groups,
+ )
+ m.weight.set_value(w)
+ m.bias.set_value(b)
+ return m
+
+
+class Residual(nn.Layer):
+ def __init__(self, m, drop=0.0):
+ super().__init__()
+ self.m = m
+ self.drop = drop
+
+ def forward(self, x):
+ if self.training and self.drop > 0:
+ return (
+ x
+ + self.m(x)
+ * paddle.rand(x.size(0), 1, 1, 1)
+ .ge_(self.drop)
+ .div(1 - self.drop)
+ .detach()
+ )
+ else:
+ return x + self.m(x)
+
+ @paddle.no_grad()
+ def fuse(self):
+ if isinstance(self.m, Conv2D_BN):
+ m = self.m.fuse()
+ assert m._groups == m.in_channels
+ identity = paddle.ones([m.weight.shape[0], m.weight.shape[1], 1, 1])
+ identity = nn.functional.pad(identity, [1, 1, 1, 1])
+ m.weight += identity
+ return m
+ elif isinstance(self.m, nn.Conv2D):
+ m = self.m
+ assert m._groups != m.in_channels
+ identity = paddle.ones([m.weight.shape[0], m.weight.shape[1], 1, 1])
+ identity = nn.functional.pad(identity, [1, 1, 1, 1])
+ m.weight += identity
+ return m
+ else:
+ return self
+
+
+class RepVGGDW(nn.Layer):
+ def __init__(self, ed) -> None:
+ super().__init__()
+ self.conv = Conv2D_BN(ed, ed, 3, 1, 1, groups=ed)
+ self.conv1 = nn.Conv2D(ed, ed, 1, 1, 0, groups=ed)
+ self.dim = ed
+ self.bn = nn.BatchNorm2D(ed)
+
+ def forward(self, x):
+ return self.bn((self.conv(x) + self.conv1(x)) + x)
+
+ @paddle.no_grad()
+ def fuse(self):
+ conv = self.conv.fuse()
+ conv1 = self.conv1
+
+ conv_w = conv.weight
+ conv_b = conv.bias
+ conv1_w = conv1.weight
+ conv1_b = conv1.bias
+
+ conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1])
+
+ identity = nn.functional.pad(
+ paddle.ones([conv1_w.shape[0], conv1_w.shape[1], 1, 1]), [1, 1, 1, 1]
+ )
+
+ final_conv_w = conv_w + conv1_w + identity
+ final_conv_b = conv_b + conv1_b
+
+ conv.weight.set_value(final_conv_w)
+ conv.bias.set_value(final_conv_b)
+
+ bn = self.bn
+ w = bn.weight / (bn._variance + bn._epsilon) ** 0.5
+ w = conv.weight * w[:, None, None, None]
+ b = (
+ bn.bias
+ + (conv.bias - bn._mean) * bn.weight / (bn._variance + bn._epsilon) ** 0.5
+ )
+ conv.weight.set_value(w)
+ conv.bias.set_value(b)
+ return conv
+
+
+class RepViTBlock(nn.Layer):
+ def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
+ super(RepViTBlock, self).__init__()
+
+ self.identity = stride == 1 and inp == oup
+ assert hidden_dim == 2 * inp
+
+ if stride != 1:
+ self.token_mixer = nn.Sequential(
+ Conv2D_BN(
+ inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp
+ ),
+ SEModule(inp, 0.25) if use_se else nn.Identity(),
+ Conv2D_BN(inp, oup, ks=1, stride=1, pad=0),
+ )
+ self.channel_mixer = Residual(
+ nn.Sequential(
+ # pw
+ Conv2D_BN(oup, 2 * oup, 1, 1, 0),
+ nn.GELU() if use_hs else nn.GELU(),
+ # pw-linear
+ Conv2D_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
+ )
+ )
+ else:
+ assert self.identity
+ self.token_mixer = nn.Sequential(
+ RepVGGDW(inp),
+ SEModule(inp, 0.25) if use_se else nn.Identity(),
+ )
+ self.channel_mixer = Residual(
+ nn.Sequential(
+ # pw
+ Conv2D_BN(inp, hidden_dim, 1, 1, 0),
+ nn.GELU() if use_hs else nn.GELU(),
+ # pw-linear
+ Conv2D_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
+ )
+ )
+
+ def forward(self, x):
+ return self.channel_mixer(self.token_mixer(x))
+
+
+class RepViT(nn.Layer):
+ def __init__(self, cfgs, in_channels=3, out_indices=None):
+ super(RepViT, self).__init__()
+ # setting of inverted residual blocks
+ self.cfgs = cfgs
+
+ # building first layer
+ input_channel = self.cfgs[0][2]
+ patch_embed = nn.Sequential(
+ Conv2D_BN(in_channels, input_channel // 2, 3, 2, 1),
+ nn.GELU(),
+ Conv2D_BN(input_channel // 2, input_channel, 3, 2, 1),
+ )
+ layers = [patch_embed]
+ # building inverted residual blocks
+ block = RepViTBlock
+ for k, t, c, use_se, use_hs, s in self.cfgs:
+ output_channel = _make_divisible(c, 8)
+ exp_size = _make_divisible(input_channel * t, 8)
+ layers.append(
+ block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)
+ )
+ input_channel = output_channel
+ self.features = nn.LayerList(layers)
+ self.out_indices = out_indices
+ if out_indices is not None:
+ self.out_channels = [self.cfgs[ids - 1][2] for ids in out_indices]
+ else:
+ self.out_channels = self.cfgs[-1][2]
+
+ def forward(self, x):
+ if self.out_indices is not None:
+ return self.forward_det(x)
+ return self.forward_rec(x)
+
+ def forward_det(self, x):
+ outs = []
+ for i, f in enumerate(self.features):
+ x = f(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return outs
+
+ def forward_rec(self, x):
+ for f in self.features:
+ x = f(x)
+ h = x.shape[2]
+ x = nn.functional.avg_pool2d(x, [h, 2])
+ return x
+
+
+def RepSVTR(in_channels=3):
+ """
+ Constructs a MobileNetV3-Large model
+ """
+ # k, t, c, SE, HS, s
+ cfgs = [
+ [3, 2, 96, 1, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 192, 0, 1, (2, 1)],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 384, 0, 1, (2, 1)],
+ [3, 2, 384, 1, 1, 1],
+ [3, 2, 384, 0, 1, 1],
+ ]
+ return RepViT(cfgs, in_channels=in_channels)
+
+
+def RepSVTR_det(in_channels=3, out_indices=[2, 5, 10, 13]):
+ """
+ Constructs a MobileNetV3-Large model
+ """
+ # k, t, c, SE, HS, s
+ cfgs = [
+ [3, 2, 48, 1, 0, 1],
+ [3, 2, 48, 0, 0, 1],
+ [3, 2, 96, 0, 0, 2],
+ [3, 2, 96, 1, 0, 1],
+ [3, 2, 96, 0, 0, 1],
+ [3, 2, 192, 0, 1, 2],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 192, 1, 1, 1],
+ [3, 2, 192, 0, 1, 1],
+ [3, 2, 384, 0, 1, 2],
+ [3, 2, 384, 1, 1, 1],
+ [3, 2, 384, 0, 1, 1],
+ ]
+ return RepViT(cfgs, in_channels=in_channels, out_indices=out_indices)
diff --git a/ppocr/modeling/backbones/rec_svtrv2.py b/ppocr/modeling/backbones/rec_svtrv2.py
new file mode 100644
index 0000000000..31ce55a65a
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_svtrv2.py
@@ -0,0 +1,575 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import ParamAttr
+from paddle.nn.initializer import KaimingNormal
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+
+trunc_normal_ = TruncatedNormal(std=0.02)
+normal_ = Normal
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
+
+
+def drop_path(x, drop_prob=0.0, training=False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
+ shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
+ random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
+ random_tensor = paddle.floor(random_tensor) # binarize
+ output = x.divide(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Identity(nn.Layer):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, input):
+ return input
+
+
+class Mlp(nn.Layer):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias_attr=False,
+ groups=1,
+ act=nn.GELU,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
+ bias_attr=bias_attr,
+ )
+ self.norm = nn.BatchNorm2D(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class Attention(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.dim = dim
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ qkv = (
+ self.qkv(x)
+ .reshape((0, -1, 3, self.num_heads, self.head_dim))
+ .transpose((2, 0, 3, 1, 4))
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
+ attn = nn.functional.softmax(attn, axis=-1)
+ attn = self.attn_drop(attn)
+ x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ epsilon=1e-6,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim, epsilon=epsilon)
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = norm_layer(dim, epsilon=epsilon)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x):
+ x = self.norm1(x + self.drop_path(self.mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ return x
+
+
+class ConvBlock(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ epsilon=1e-6,
+ ):
+ super().__init__()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.norm1 = norm_layer(dim, epsilon=epsilon)
+ self.mixer = nn.Conv2D(
+ dim,
+ dim,
+ 5,
+ 1,
+ 2,
+ groups=num_heads,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = norm_layer(dim, epsilon=epsilon)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x):
+ C, H, W = x.shape[1:]
+ x = x + self.drop_path(self.mixer(x))
+ x = self.norm1(x.flatten(2).transpose([0, 2, 1]))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ x = x.transpose([0, 2, 1]).reshape([0, C, H, W])
+ return x
+
+
+class FlattenTranspose(nn.Layer):
+ def forward(self, x):
+ return x.flatten(2).transpose([0, 2, 1])
+
+
+class SubSample2D(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride=[2, 1],
+ ):
+ super().__init__()
+ self.conv = nn.Conv2D(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ )
+ self.norm = nn.LayerNorm(out_channels)
+
+ def forward(self, x, sz):
+ # print(x.shape)
+ x = self.conv(x)
+ C, H, W = x.shape[1:]
+ x = self.norm(x.flatten(2).transpose([0, 2, 1]))
+ x = x.transpose([0, 2, 1]).reshape([0, C, H, W])
+ return x, [H, W]
+
+
+class SubSample1D(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride=[2, 1],
+ ):
+ super().__init__()
+ self.conv = nn.Conv2D(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ )
+ self.norm = nn.LayerNorm(out_channels)
+
+ def forward(self, x, sz):
+ C = x.shape[-1]
+ x = x.transpose([0, 2, 1]).reshape([0, C, sz[0], sz[1]])
+ x = self.conv(x)
+ C, H, W = x.shape[1:]
+ x = self.norm(x.flatten(2).transpose([0, 2, 1]))
+ return x, [H, W]
+
+
+class IdentitySize(nn.Layer):
+ def forward(self, x, sz):
+ return x, sz
+
+
+class SVTRStage(nn.Layer):
+ def __init__(
+ self,
+ dim=64,
+ out_dim=256,
+ depth=3,
+ mixer=["Local"] * 3,
+ sub_k=[2, 1],
+ num_heads=2,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path=[0.1] * 3,
+ norm_layer=nn.LayerNorm,
+ act=nn.GELU,
+ eps=1e-6,
+ downsample=None,
+ **kwargs
+ ):
+ super().__init__()
+ self.dim = dim
+
+ conv_block_num = sum([1 if mix == "Conv" else 0 for mix in mixer])
+ blocks = []
+ for i in range(depth):
+ if mixer[i] == "Conv":
+ blocks.append(
+ ConvBlock(
+ dim=dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ drop=drop_rate,
+ act_layer=act,
+ drop_path=drop_path[i],
+ norm_layer=norm_layer,
+ epsilon=eps,
+ )
+ )
+ else:
+ blocks.append(
+ Block(
+ dim=dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=act,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path[i],
+ norm_layer=norm_layer,
+ epsilon=eps,
+ )
+ )
+ if i == conv_block_num - 1 and mixer[-1] != "Conv":
+ blocks.append(FlattenTranspose())
+ self.blocks = nn.Sequential(*blocks)
+ if downsample:
+ if mixer[-1] == "Conv":
+ self.downsample = SubSample2D(dim, out_dim, stride=sub_k)
+ elif mixer[-1] == "Global":
+ self.downsample = SubSample1D(dim, out_dim, stride=sub_k)
+ else:
+ self.downsample = IdentitySize()
+
+ def forward(self, x, sz):
+ x = self.blocks(x)
+ x, sz = self.downsample(x, sz)
+ return x, sz
+
+
+class ADDPosEmbed(nn.Layer):
+ def __init__(self, feat_max_size=[8, 32], embed_dim=768):
+ super().__init__()
+ pos_embed = paddle.zeros(
+ [1, feat_max_size[0] * feat_max_size[1], embed_dim], dtype=paddle.float32
+ )
+ trunc_normal_(pos_embed)
+ pos_embed = pos_embed.transpose([0, 2, 1]).reshape(
+ [1, embed_dim, feat_max_size[0], feat_max_size[1]]
+ )
+ self.pos_embed = self.create_parameter(
+ [1, embed_dim, feat_max_size[0], feat_max_size[1]]
+ )
+ self.add_parameter("pos_embed", self.pos_embed)
+ self.pos_embed.set_value(pos_embed)
+
+ def forward(self, x):
+ sz = x.shape[2:]
+ x = x + self.pos_embed[:, :, : sz[0], : sz[1]]
+ return x
+
+
+class POPatchEmbed(nn.Layer):
+ """Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ in_channels=3,
+ feat_max_size=[8, 32],
+ embed_dim=768,
+ use_pos_embed=False,
+ flatten=False,
+ ):
+ super().__init__()
+ patch_embed = [
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None,
+ ),
+ ]
+ if use_pos_embed:
+ patch_embed.append(ADDPosEmbed(feat_max_size, embed_dim))
+ if flatten:
+ patch_embed.append(FlattenTranspose())
+ self.patch_embed = nn.Sequential(*patch_embed)
+
+ def forward(self, x):
+ sz = x.shape[2:]
+ x = self.patch_embed(x)
+ return x, [sz[0] // 4, sz[1] // 4]
+
+
+class LastStage(nn.Layer):
+ def __init__(self, in_channels, out_channels, last_drop, out_char_num):
+ super().__init__()
+ self.last_conv = nn.Linear(in_channels, out_channels, bias_attr=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
+
+ def forward(self, x, sz):
+ x = x.reshape([0, sz[0], sz[1], x.shape[-1]])
+ x = x.mean(1)
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ return x, [1, sz[1]]
+
+
+class OutPool(nn.Layer):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, sz):
+ C = x.shape[-1]
+ x = x.transpose([0, 2, 1]).reshape([0, C, sz[0], sz[1]])
+ x = nn.functional.avg_pool2d(x, [sz[0], 2])
+ return x, [1, sz[1] // 2]
+
+
+class Feat2D(nn.Layer):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, sz):
+ C = x.shape[-1]
+ x = x.transpose([0, 2, 1]).reshape([0, C, sz[0], sz[1]])
+ return x, sz
+
+
+class SVTRv2(nn.Layer):
+ def __init__(
+ self,
+ max_sz=[32, 128],
+ in_channels=3,
+ out_channels=192,
+ out_char_num=25,
+ depths=[3, 6, 3],
+ dims=[64, 128, 256],
+ mixer=[["Conv"] * 3, ["Conv"] * 3 + ["Global"] * 3, ["Global"] * 3],
+ use_pos_embed=False,
+ sub_k=[[1, 1], [2, 1], [1, 1]],
+ num_heads=[2, 4, 8],
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ act=nn.GELU,
+ last_stage=False,
+ eps=1e-6,
+ use_pool=False,
+ feat2d=False,
+ **kwargs
+ ):
+ super().__init__()
+ num_stages = len(depths)
+ self.num_features = dims[-1]
+
+ feat_max_size = [max_sz[0] // 4, max_sz[1] // 4]
+ self.pope = POPatchEmbed(
+ in_channels=in_channels,
+ feat_max_size=feat_max_size,
+ embed_dim=dims[0],
+ use_pos_embed=use_pos_embed,
+ flatten=mixer[0][0] != "Conv",
+ )
+
+ dpr = np.linspace(0, drop_path_rate, sum(depths)) # stochastic depth decay rule
+
+ self.stages = nn.LayerList()
+ for i_stage in range(num_stages):
+ stage = SVTRStage(
+ dim=dims[i_stage],
+ out_dim=dims[i_stage + 1] if i_stage < num_stages - 1 else 0,
+ depth=depths[i_stage],
+ mixer=mixer[i_stage],
+ sub_k=sub_k[i_stage],
+ num_heads=num_heads[i_stage],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_stage]) : sum(depths[: i_stage + 1])],
+ norm_layer=norm_layer,
+ act=act,
+ downsample=False if i_stage == num_stages - 1 else True,
+ eps=eps,
+ )
+ self.stages.append(stage)
+
+ self.out_channels = self.num_features
+ self.last_stage = last_stage
+ if last_stage:
+ self.out_channels = out_channels
+ self.stages.append(
+ LastStage(self.num_features, out_channels, last_drop, out_char_num)
+ )
+ if use_pool:
+ self.stages.append(OutPool())
+
+ if feat2d:
+ self.stages.append(Feat2D())
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward(self, x):
+ x, sz = self.pope(x)
+ for stage in self.stages:
+ x, sz = stage(x, sz)
+ return x
diff --git a/ppocr/modeling/heads/rec_multi_head.py b/ppocr/modeling/heads/rec_multi_head.py
index c7005c108e..50887d7c86 100644
--- a/ppocr/modeling/heads/rec_multi_head.py
+++ b/ppocr/modeling/heads/rec_multi_head.py
@@ -149,5 +149,5 @@ def forward(self, x, targets=None):
head_out["sar"] = sar_out
else:
gtc_out = self.gtc_head(self.before_gtc(x), targets[1:])
- head_out["nrtr"] = gtc_out
+ head_out["gtc"] = gtc_out
return head_out