Fix some issues with LaTeXOCR in paddleX (PaddlePaddle#13646)
* repair_some_Bug_for_paddlex

* style2

* style2

* add_epilson_for groupnorm
liuhongen1234567 authored Aug 14, 2024
1 parent 6dc0211 commit 5f0b90a
Showing 10 changed files with 55 additions and 92 deletions.
12 changes: 6 additions & 6 deletions docs/algorithm/formula_recognition/algorithm_rec_latex_ocr.en.md
@@ -17,7 +17,7 @@ Please refer to ["Environment Preparation"](../../ppocr/environment.en.md) to co

Furthermore, additional dependencies need to be installed:
```shell
pip install "tokenizers==0.19.1" "imagesize"
pip install -r docs/algorithm/formula_recognition/requirements.txt
```

## 3. Model Training / Evaluation / Prediction
@@ -61,16 +61,16 @@ Evaluation:
```
# GPU evaluation
# Validation set evaluation
-python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True
+python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
# Test set evaluation
-python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
+python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
```

Prediction:

```
# The configuration file used for prediction must match the one used for training
-python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
+python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Global.infer_img='./docs/datasets/images/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
```

## 4. Inference and Deployment
@@ -79,15 +79,15 @@ python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Ba
First, the model saved during the LaTeX-OCR printed mathematical expression recognition training process is converted into an inference model. You can use the following command to convert:

```
-python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True
+python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/
# The default output max length of the model is 512.
```

For LaTeX-OCR printed mathematical expression recognition model inference, the following commands can be executed:

```
-python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"
+python3 tools/infer/predict_rec.py --image_dir='./docs/datasets/images/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"
```

### 4.2 C++ Inference
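The documentation commands above got shorter because this commit moves the special-case options into the tools themselves: `tools/eval.py` now turns BLEU scoring on for LaTeXOCR, and `tools/export_model.py` / `tools/infer_rec.py` inject the export flags (see those diffs below). A minimal sketch of the override logic being centralized, assuming a parsed YAML config dict:

```python
# Sketch of the overrides this commit moves into the tools (mirrors the lines
# added in tools/eval.py, tools/export_model.py, and tools/infer_rec.py).
config = {
    "Architecture": {"algorithm": "LaTeXOCR", "Backbone": {}, "Head": {}},
    "Metric": {},
}
if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
    config["Metric"]["cal_blue_score"] = True  # eval: always report BLEU
    config["Architecture"]["Backbone"]["is_predict"] = True
    config["Architecture"]["Backbone"]["is_export"] = True
    config["Architecture"]["Head"]["is_export"] = True
```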
12 changes: 6 additions & 6 deletions docs/algorithm/formula_recognition/algorithm_rec_latex_ocr.md
@@ -17,7 +17,7 @@

In addition, extra dependencies need to be installed:
```shell
pip install "tokenizers==0.19.1" "imagesize"
pip install -r docs/algorithm/formula_recognition/requirements.txt
```

## 3. Model Training, Evaluation, and Prediction
@@ -69,17 +69,17 @@ python3 tools/train.py -c configs/rec/rec_latex_ocr.yml -o Global.eval_batch_ste
```shell
# Note: set the pretrained_model path to a local path. If you use a model trained and saved by yourself, change the path and filename to {path/to/weights}/{model_name}.
# Validation set evaluation
-python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True
+python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
# Test set evaluation
-python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
+python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
```

### 3.4 Prediction

Use the following command to run prediction on a single image:
```shell
# Note: set the pretrained_model path to a local path.
-python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
+python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Global.infer_img='./docs/datasets/images/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
# To predict all images in a folder, set infer_img to a directory, e.g. Global.infer_img='./doc/datasets/pme_demo/'.
```

@@ -90,7 +90,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Ba

```shell
# Note: set the pretrained_model path to a local path.
-python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True
+python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/

# The current static-graph model supports a maximum output length of 512
```
@@ -109,7 +109,7 @@ python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrai
Run the following command to perform model inference:

```shell
-python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"
+python3 tools/infer/predict_rec.py --image_dir='./docs/datasets/images/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"

# To predict all images in a folder, set image_dir to a directory, e.g. --image_dir='./doc/datasets/pme_demo/'.
```
2 changes: 2 additions & 0 deletions docs/algorithm/formula_recognition/requirements.txt
@@ -0,0 +1,2 @@
+tokenizers==0.19.1
+imagesize
7 changes: 4 additions & 3 deletions ppocr/metrics/rec_metric.py
@@ -261,14 +261,15 @@ def get_metric(self):
self.reset()
if self.cal_blue_score:
return {
"blue_score ": cur_blue_score,
"edit distance ": cur_edit_distance,
"exp_rate ": cur_exp_rate,
"blue_score": cur_blue_score,
"edit distance": cur_edit_distance,
"exp_rate": cur_exp_rate,
"exp_rate<=1 ": cur_exp_1,
"exp_rate<=2 ": cur_exp_2,
"exp_rate<=3 ": cur_exp_3,
}
else:

return {
"edit distance": cur_edit_distance,
"exp_rate": cur_exp_rate,
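The trailing spaces in the first three metric keys were a real bug, not cosmetics: any consumer indexing the returned dict by the obvious name would miss the entry. A toy illustration (metric values invented for the example):

```python
# Before the fix, the keys carried trailing spaces.
metrics_before = {"blue_score ": 0.88, "edit distance ": 0.05, "exp_rate ": 0.41}
metrics_after = {"blue_score": 0.88, "edit distance": 0.05, "exp_rate": 0.41}

print(metrics_after["exp_rate"])  # 0.41
# metrics_before["exp_rate"]      # would raise KeyError: 'exp_rate'
```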
73 changes: 6 additions & 67 deletions ppocr/modeling/backbones/rec_resnetv2.py
@@ -540,71 +540,6 @@ def forward(self, x):
return x * self.weight.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])


-from paddle.common_ops_import import (
-    LayerHelper,
-    check_type,
-    check_variable_and_dtype,
-)
-
-
-def group_norm(
-    input,
-    groups,
-    epsilon=1e-05,
-    weight=None,
-    bias=None,
-    act=None,
-    data_layout="NCHW",
-    name=None,
-):
-    helper = LayerHelper("group_norm", **locals())
-    dtype = helper.input_dtype()
-    check_variable_and_dtype(
-        input,
-        "input",
-        ["float16", "uint16", "float32", "float64"],
-        "group_norm",
-    )
-    # create intput and parameters
-    inputs = {"X": input}
-    input_shape = input.shape
-    if len(input_shape) < 2:
-        raise ValueError(
-            f"The dimensions of Op(static.nn.group_norm)'s input should be more than 1. But received {len(input_shape)}"
-        )
-    if data_layout != "NCHW" and data_layout != "NHWC":
-        raise ValueError(
-            "Param(data_layout) of Op(static.nn.group_norm) got wrong value: received "
-            + data_layout
-            + " but only NCHW or NHWC supported."
-        )
-    channel_num = input_shape[1] if data_layout == "NCHW" else input_shape[-1]
-    param_shape = [channel_num]
-    inputs["Scale"] = weight
-    inputs["Bias"] = bias
-    # create output
-    mean_out = helper.create_variable(dtype=dtype, stop_gradient=True)
-    variance_out = helper.create_variable(dtype=dtype, stop_gradient=True)
-    group_norm_out = helper.create_variable(dtype=dtype)
-
-    helper.append_op(
-        type="group_norm",
-        inputs=inputs,
-        outputs={
-            "Y": group_norm_out,
-            "Mean": mean_out,
-            "Variance": variance_out,
-        },
-        attrs={
-            "epsilon": epsilon,
-            "groups": groups,
-            "data_layout": data_layout,
-        },
-    )
-
-    return helper.append_activation(group_norm_out)


class GroupNormAct(nn.GroupNorm):
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(
@@ -630,8 +565,12 @@ def __init__(
self.act = nn.Identity()

def forward(self, x):
-        x = group_norm(
-            x, self._num_groups, self._epsilon, weight=self.weight, bias=self.bias
+        x = F.group_norm(
+            x,
+            num_groups=self._num_groups,
+            epsilon=self._epsilon,
+            weight=self.weight,
+            bias=self.bias,
)
x = self.act(x)
return x
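The hand-rolled static-graph `group_norm` helper above is deleted in favor of Paddle's built-in functional op, with epsilon now passed through explicitly (the `add_epilson_for groupnorm` commit). A minimal equivalence check, assuming Paddle ≥ 2.5 where `paddle.nn.functional.group_norm` is public API:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

# nn.GroupNorm and F.group_norm should agree when given the same
# groups, epsilon, and affine parameters.
x = paddle.randn([2, 32, 8, 8])
gn = nn.GroupNorm(num_groups=8, num_channels=32, epsilon=1e-5)
y_layer = gn(x)
y_func = F.group_norm(x, num_groups=8, epsilon=1e-5, weight=gn.weight, bias=gn.bias)
print(float((y_layer - y_func).abs().max()))  # ~0.0
```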
5 changes: 1 addition & 4 deletions ppocr/modeling/heads/rec_latexocr_head.py
@@ -977,10 +977,7 @@ def generate_export(
paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1
).all()
):
-                out = out[:, t:]
-                if num_dims == 1:
-                    out = out.squeeze(0)
-                return out
+                break
i_idx += 1
out = out[:, t:]
if num_dims == 1:
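In `generate_export`, the early `return` inside the decoding loop becomes a `break`, so the prefix trimming and the optional squeeze now run in exactly one place after the loop. A toy, runnable sketch of the resulting control flow (the decoding step is faked; `eos_token`, `t`, and `num_dims` stand in for the real state):

```python
import paddle

# Stand-ins for the real decode state (assumptions, not the actual model).
eos_token, t, num_dims, max_steps = 2, 1, 2, 8
out = paddle.to_tensor([[0, 5, 7]])  # [batch, tokens]

i_idx = 0
while i_idx < max_steps:
    next_tok = paddle.to_tensor([[eos_token if i_idx >= 3 else 4]])  # fake step
    out = paddle.concat([out, next_tok], axis=1)
    # Same stopping test as the diff: every sequence has emitted at least one EOS.
    if (paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1).all():
        break  # post-processing now happens once, below the loop
    i_idx += 1

out = out[:, t:]  # trim the start token, in one place
if num_dims == 1:
    out = out.squeeze(0)
print(out.numpy())
```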
5 changes: 4 additions & 1 deletion ppocr/utils/formula_utils/math_txt2pkl.py
@@ -15,6 +15,7 @@
import pickle
from tqdm import tqdm
import os
+import math
from paddle.utils import try_import
from collections import defaultdict
import glob
@@ -41,7 +42,9 @@ def txt2pickle(images, equations, save_dir):
min_dimensions[0] <= width <= max_dimensions[0]
and min_dimensions[1] <= height <= max_dimensions[1]
):
-                data[(width, height)].append((eqs[indices[i]], im))
+                divide_h = math.ceil(height / 16) * 16
+                divide_w = math.ceil(width / 16) * 16
+                data[(divide_w, divide_h)].append((eqs[indices[i]], im))
data = dict(data)
with open(save_p, "wb") as file:
pickle.dump(data, file)
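The dataset pickling now keys each bucket by the image size rounded up to a multiple of 16 rather than the exact `(width, height)`, so crops of near-identical size land in the same batchable bucket. A small sketch of the new key function:

```python
import math

def bucket_key(width: int, height: int) -> tuple:
    """Round both dimensions up to the nearest multiple of 16, as in txt2pickle."""
    return (math.ceil(width / 16) * 16, math.ceil(height / 16) * 16)

print(bucket_key(120, 30))  # (128, 32)
print(bucket_key(128, 32))  # (128, 32) -- same bucket as above
print(bucket_key(129, 33))  # (144, 48)
```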
1 change: 1 addition & 0 deletions tools/eval.py
@@ -107,6 +107,7 @@ def main():
model_type = "can"
elif config["Architecture"]["algorithm"] == "LaTeXOCR":
model_type = "latexocr"
config["Metric"]["cal_blue_score"] = True
else:
model_type = config["Architecture"]["model_type"]
else:
24 changes: 19 additions & 5 deletions tools/export_model.py
@@ -22,6 +22,7 @@
import argparse

import yaml
+import json
import paddle
from paddle.jit import to_static
from collections import OrderedDict
@@ -219,11 +220,18 @@ def dump_infer_config(config, path, logger):
for k, v in config["PostProcess"].items():
postprocess[k] = v

if config["Global"].get("character_dict_path") is not None:
with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
lines = f.readlines()
character_dict = [line.strip("\n") for line in lines]
postprocess["character_dict"] = character_dict
if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
tokenizer_file = config["Global"].get("rec_char_dict_path")
if tokenizer_file is not None:
with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
character_dict = json.load(tokenizer_config_handle)
postprocess["character_dict"] = character_dict
else:
if config["Global"].get("character_dict_path") is not None:
with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
lines = f.readlines()
character_dict = [line.strip("\n") for line in lines]
postprocess["character_dict"] = character_dict

infer_cfg["PostProcess"] = postprocess

@@ -288,6 +296,12 @@ def main():
# for sr algorithm
if config["Architecture"]["model_type"] == "sr":
config["Architecture"]["Transform"]["infer_mode"] = True

+    # for latexocr algorithm
+    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
+        config["Architecture"]["Backbone"]["is_predict"] = True
+        config["Architecture"]["Backbone"]["is_export"] = True
+        config["Architecture"]["Head"]["is_export"] = True
model = build_model(config["Architecture"])
load_model(config, model, model_type=config["Architecture"]["model_type"])
model.eval()
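A consequence of the `dump_infer_config` change above: for LaTeXOCR, `postprocess["character_dict"]` now holds a parsed tokenizer JSON object (the `tokenizers==0.19.1` dependency suggests a HuggingFace-style tokenizer file) instead of a list of characters read line by line. A quick sketch of the difference, assuming the dict path shown in the docs above:

```python
import json

# LaTeXOCR: a structured tokenizer definition, parsed with json.load.
with open("ppocr/utils/dict/latex_ocr_tokenizer.json", encoding="utf-8") as f:
    character_dict = json.load(f)
print(type(character_dict))  # dict, not a list of characters

# Other rec models keep the old behavior: one character per line.
# with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
#     character_dict = [line.strip("\n") for line in f.readlines()]
```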
6 changes: 6 additions & 0 deletions tools/infer_rec.py
@@ -82,6 +82,12 @@ def main():
config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num

if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
config["Architecture"]["Backbone"]["is_predict"] = True
config["Architecture"]["Backbone"]["is_export"] = True
config["Architecture"]["Head"]["is_export"] = True

model = build_model(config["Architecture"])

load_model(config, model)
