Skip to content

Commit

Permalink
CV套件建设专项活动 - 文字识别返回单字识别坐标 (#10515) (#10537)
Browse files Browse the repository at this point in the history
* modification of return word box

* update_implements

* Update rec_postprocess.py

* Update utility.py
  • Loading branch information
ToddBear authored Aug 10, 2023
1 parent bf6ff0b commit b3f9f68
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 15 deletions.
78 changes: 73 additions & 5 deletions ppocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,66 @@ def pred_reverse(self, pred):
def add_special_char(self, dict_character):
return dict_character

def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
def get_word_info(self, text, selection):
"""
Group the decoded characters and record the corresponding decoded positions.
Args:
text: the decoded text
selection: the bool array that identifies which columns of features are decoded as non-separated characters
Returns:
word_list: list of the grouped words
word_col_list: list of decoding positions corresponding to each character in the grouped word
state_list: list of marker to identify the type of grouping words, including two types of grouping words:
- 'cn': continous chinese characters (e.g., 你好啊)
- 'en&num': continous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
"""
state = None
word_content = []
word_col_content = []
word_list = []
word_col_list = []
state_list = []
valid_col = np.where(selection==True)[0]

for c_i, char in enumerate(text):
if '\u4e00' <= char <= '\u9fff':
c_state = 'cn'
elif bool(re.search('[a-zA-Z0-9]', char)):
c_state = 'en&num'
else:
c_state = 'splitter'

if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])): # grouping floting number
c_state = 'en&num'
if char == '-' and state == "en&num": # grouping word with '-', such as 'state-of-the-art'
c_state = 'en&num'

if state == None:
state = c_state

if state != c_state:
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)
word_content = []
word_col_content = []
state = c_state

if state != "splitter":
word_content.append(char)
word_col_content.append(valid_col[c_i])

if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)

return word_list, word_col_list, state_list

def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
Expand Down Expand Up @@ -95,8 +154,12 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):

if self.reverse: # for arabic rec
text = self.pred_reverse(text)

result_list.append((text, np.mean(conf_list).tolist()))

if return_word_box:
word_list, word_col_list, state_list = self.get_word_info(text, selection)
result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list]))
else:
result_list.append((text, np.mean(conf_list).tolist()))
return result_list

def get_ignored_tokens(self):
Expand All @@ -111,14 +174,19 @@ def __init__(self, character_dict_path=None, use_space_char=False,
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)

def __call__(self, preds, label=None, *args, **kwargs):
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box)
if return_word_box:
for rec_idx, rec in enumerate(text):
wh_ratio = kwargs['wh_ratio_list'][rec_idx]
max_wh_ratio = kwargs['max_wh_ratio']
rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio)
if label is None:
return text
label = self.decode(label)
Expand Down
26 changes: 19 additions & 7 deletions ppstructure/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.utility import parse_args, draw_structure_result, cal_ocr_word_box

logger = get_logger()

Expand Down Expand Up @@ -79,6 +79,8 @@ def __init__(self, args):
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
self.kie_predictor = SerRePredictor(args)

self.return_word_box = args.return_word_box

def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict = {
'image_orientation': 0,
Expand Down Expand Up @@ -156,17 +158,27 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
]
res = []
for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res
rec_str, rec_conf = rec_res[0], rec_res[1]
for token in style_token:
if token in rec_str:
rec_str = rec_str.replace(token, '')
if not self.recovery:
box += [x1, y1]
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
if self.return_word_box:
word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist(),
'text_word': word_box_content_list,
'text_word_region': word_box_list
})
else:
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
res_list.append({
'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
Expand Down
62 changes: 62 additions & 0 deletions ppstructure/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
import ast
from PIL import Image, ImageDraw, ImageFont
import numpy as np
<<<<<<< HEAD
from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args

=======
from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
import math
>>>>>>> 1e11f254 (CV套件建设专项活动 - 文字识别返回单字识别坐标 (#10515))

def init_args():
parser = infer_args()
Expand Down Expand Up @@ -152,6 +157,63 @@ def draw_structure_result(image, result, font_path):
txts.append(text_result['text'])
scores.append(text_result['confidence'])

if 'text_word_region' in text_result:
for word_region in text_result['text_word_region']:
char_box = word_region
box_height = int(
math.sqrt((char_box[0][0] - char_box[3][0])**2 + (char_box[0][1] - char_box[3][1])**2))
box_width = int(
math.sqrt((char_box[0][0] - char_box[1][0])**2 + (char_box[0][1] - char_box[1][1])**2))
if box_height == 0 or box_width == 0:
continue
boxes.append(word_region)
txts.append("")
scores.append(1.0)

im_show = draw_ocr_box_txt(
img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
return im_show

This comment has been minimized.

Copy link
@ArvinCharl

ArvinCharl Aug 30, 2023

当文本有一些倾斜时, cal_ocr_word_box函数并不能正确处理单元格的四个点坐标,建议修改为下:

def cal_ocr_word_box(rec_str, box, rec_word_info):
    """Calculate the detection frame for each word based on the results of recognition and detection of ocr"""

    # 解压包含列数,单词列表,单词列列表和状态列表的信息
    col_num, word_list, word_col_list, state_list = rec_word_info

    # 提取矩形框的坐标: 修改为四个点, 因为倾斜的话, 可能右上和左上的高度有差异, 也就是y_start 不一定是第一个点
    top_left = box[0]
    top_right = box[1]
    bottom_right = box[2]
    bottom_left = box[3]

    # 计算单元格宽度
    cell_width = (top_right[0] - top_left[0]) / col_num

    # 定义相关的存储列表
    word_box_list = []
    word_box_content_list = []
    cn_width_list = []
    cn_col_list = []

    # 遍历每个单词、单词列和状态
    for word, word_col, state in zip(word_list, word_col_list, state_list):
        # 检查单词是否是中文(cn)
        if state == "cn":
            # 计算每个字符的宽度
            if len(word_col) != 1:
                char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
                char_width = char_seq_length / (len(word_col) - 1)
                cn_width_list.append(char_width)

            # 将单词列和字符添加到对应的列表中
            cn_col_list += word_col
            word_box_content_list += word

        # 如果单词不是中文
        else:
            # 计算单词的单元格坐标
            cell_x_start = top_left[0] + int(word_col[0] * cell_width)
            cell_x_end = top_left[0] + int((word_col[-1] + 1) * cell_width)

            # 根据斜率来计算每个点y方向偏移量
            slope = (top_left[1] - top_right[1]) / (top_right[0] - top_left[0])
            angle = math.degrees(math.atan(slope))
            # 角度为正,y减;偏移量 为 邻边 * 正切
            if angle > 0:
                # 每个单元格的left和right的y方向偏移量都不同
                # 计算单元格左侧点
                adjacent_cell_y_left = cell_x_start - top_left[0]
                opposite_cell_y_left = adjacent_cell_y_left * math.tan(
                    math.radians(angle)
                )
                cell_y_top_left = top_left[1] - opposite_cell_y_left
                cell_y_bottom_left = bottom_left[1] - opposite_cell_y_left
                # 计算单元格右侧点
                adjacent_cell_y_right = cell_x_end - top_left[0]
                opposite_cell_y_right = adjacent_cell_y_right * math.tan(
                    math.radians(angle)
                )
                cell_y_top_right = top_left[1] - opposite_cell_y_right
                cell_y_bottom_right = bottom_left[1] - opposite_cell_y_right
            # 角度为负,y加
            else:
                # 计算单元格左侧点
                adjacent_cell_y_left = cell_x_start - top_left[0]
                opposite_cell_y_left = adjacent_cell_y_left * math.tan(
                    math.radians(angle)
                )
                cell_y_top_left = top_left[1] + opposite_cell_y_left
                cell_y_bottom_left = bottom_left[1] + opposite_cell_y_left
                # 计算单元格右侧点
                adjacent_cell_y_right = cell_x_end - top_left[0]
                opposite_cell_y_right = adjacent_cell_y_right * math.tan(
                    math.radians(angle)
                )
                cell_y_top_right = top_left[1] + opposite_cell_y_right
                cell_y_bottom_right = bottom_left[1] + opposite_cell_y_right

            # 创建表示单元格坐标的点的列表
            cell = [
                [cell_x_start, cell_y_top_left],
                [cell_x_end, cell_y_top_right],
                [cell_x_end, cell_y_bottom_right],
                [cell_x_start, cell_y_bottom_left],
            ]

            # 将单元格坐标和单词内容添加到各自的列表中
            word_box_list.append(cell)
            word_box_content_list.append("".join(word))

    # 处理识别结果中存在中文字符的情况
    if len(cn_col_list) != 0:
        # 计算平均字符宽度
        if len(cn_width_list) != 0:
            avg_char_width = np.mean(cn_width_list)
        else:
            avg_char_width = (top_right[0] - top_left[0]) / len(rec_str)

        # 遍历每个中文字符的中心点
        for i, center_idx in enumerate(cn_col_list):
            # 计算每个字符的单元格坐标
            center_x = (center_idx + 0.5) * cell_width
            cell_x_start = max(int(center_x - avg_char_width / 2), 0) + top_left[0]
            cell_x_end = (
                min(int(center_x + avg_char_width / 2), top_right[0] - top_left[0])
                + top_left[0]
            )

            # 根据斜率来计算每个点y方向偏移量
            slope = (top_left[1] - top_right[1]) / (top_right[0] - top_left[0])
            angle = math.degrees(math.atan(slope))
            # 角度为正,y减;偏移量 为 邻边 * 正切
            if angle > 0:
                # 每个单元格的left和right的y方向偏移量都不同
                # 计算单元格左侧点
                adjacent_cell_y_left = cell_x_start - top_left[0]
                opposite_cell_y_left = adjacent_cell_y_left * math.tan(
                    math.radians(angle)
                )
                cell_y_top_left = top_left[1] - opposite_cell_y_left
                cell_y_bottom_left = bottom_left[1] - opposite_cell_y_left
                # 计算单元格右侧点
                adjacent_cell_y_right = cell_x_end - top_left[0]
                opposite_cell_y_right = adjacent_cell_y_right * math.tan(
                    math.radians(angle)
                )
                cell_y_top_right = top_left[1] - opposite_cell_y_right
                cell_y_bottom_right = bottom_left[1] - opposite_cell_y_right
            # 角度为负,y加
            else:
                # 计算单元格左侧点
                adjacent_cell_y_left = cell_x_start - top_left[0]
                opposite_cell_y_left = adjacent_cell_y_left * math.tan(
                    math.radians(angle)
                )
                cell_y_top_left = top_left[1] + opposite_cell_y_left
                cell_y_bottom_left = bottom_left[1] + opposite_cell_y_left
                # 计算单元格右侧点
                adjacent_cell_y_right = cell_x_end - top_left[0]
                opposite_cell_y_right = adjacent_cell_y_right * math.tan(
                    math.radians(angle)
                )
                cell_y_top_right = top_left[1] + opposite_cell_y_right
                cell_y_bottom_right = bottom_left[1] + opposite_cell_y_right

            # 创建表示单元格坐标的点的列表
            cell = [
                [cell_x_start, cell_y_top_left],
                [cell_x_end, cell_y_top_right],
                [cell_x_end, cell_y_bottom_right],
                [cell_x_start, cell_y_bottom_left],
            ]

            # 将单元格坐标添加到word_box_list
            word_box_list.append(cell)

    # 返回单词框内容列表和单词框列表
    return word_box_content_list, word_box_list

image


def cal_ocr_word_box(rec_str, box, rec_word_info):
''' Calculate the detection frame for each word based on the results of recognition and detection of ocr'''

col_num, word_list, word_col_list, state_list = rec_word_info
box = box.tolist()
bbox_x_start = box[0][0]
bbox_x_end = box[1][0]
bbox_y_start = box[0][1]
bbox_y_end = box[2][1]

cell_width = (bbox_x_end - bbox_x_start)/col_num

word_box_list = []
word_box_content_list = []
cn_width_list = []
cn_col_list = []
for word, word_col, state in zip(word_list, word_col_list, state_list):
if state == 'cn':
if len(word_col) != 1:
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
char_width = char_seq_length/(len(word_col)-1)
cn_width_list.append(char_width)
cn_col_list += word_col
word_box_content_list += word
else:
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width)
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)
word_box_content_list.append("".join(word))
if len(cn_col_list) != 0:
if len(cn_width_list) != 0:
avg_char_width = np.mean(cn_width_list)
else:
avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str)
for center_idx in cn_col_list:
center_x = (center_idx+0.5)*cell_width
cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start
cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)

return word_box_content_list, word_box_list
10 changes: 8 additions & 2 deletions tools/infer/predict_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(self, args):
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.postprocess_params = postprocess_params
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
self.benchmark = args.benchmark
Expand All @@ -146,6 +147,7 @@ def __init__(self, args):
],
warmup=0,
logger=logger)
self.return_word_box = args.return_word_box

def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
Expand Down Expand Up @@ -415,11 +417,12 @@ def __call__(self, img_list):
valid_ratios = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
wh_ratio_list = []
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
wh_ratio_list.append(wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
Expand Down Expand Up @@ -624,7 +627,10 @@ def __call__(self, img_list):
preds = outputs
else:
preds = outputs[0]
rec_result = self.postprocess_op(preds)
if self.postprocess_params['name'] == 'CTCLabelDecode':
rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box, wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio)
else:
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
if self.benchmark:
Expand Down
2 changes: 1 addition & 1 deletion tools/infer/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __call__(self, img, cls=True):
rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
text, score = rec_result[0], rec_result[1]
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
Expand Down
4 changes: 4 additions & 0 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def init_args():

parser.add_argument("--show_log", type=str2bool, default=True)
parser.add_argument("--use_onnx", type=str2bool, default=False)

# extended function
parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery')

return parser


Expand Down

0 comments on commit b3f9f68

Please sign in to comment.