Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry-pick] #10515 #10537

Merged
merged 1 commit into from
Aug 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 73 additions & 5 deletions ppocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
@@ -67,7 +67,66 @@ def pred_reverse(self, pred):
def add_special_char(self, dict_character):
return dict_character

def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
def get_word_info(self, text, selection):
"""
Group the decoded characters and record the corresponding decoded positions.

Args:
text: the decoded text
selection: the bool array that identifies which columns of features are decoded as non-separated characters
Returns:
word_list: list of the grouped words
word_col_list: list of decoding positions corresponding to each character in the grouped word
state_list: list of marker to identify the type of grouping words, including two types of grouping words:
- 'cn': continous chinese characters (e.g., 你好啊)
- 'en&num': continous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
"""
state = None
word_content = []
word_col_content = []
word_list = []
word_col_list = []
state_list = []
valid_col = np.where(selection==True)[0]

for c_i, char in enumerate(text):
if '\u4e00' <= char <= '\u9fff':
c_state = 'cn'
elif bool(re.search('[a-zA-Z0-9]', char)):
c_state = 'en&num'
else:
c_state = 'splitter'

if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])): # grouping floting number
c_state = 'en&num'
if char == '-' and state == "en&num": # grouping word with '-', such as 'state-of-the-art'
c_state = 'en&num'

if state == None:
state = c_state

if state != c_state:
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)
word_content = []
word_col_content = []
state = c_state

if state != "splitter":
word_content.append(char)
word_col_content.append(valid_col[c_i])

if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)

return word_list, word_col_list, state_list

def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
@@ -95,8 +154,12 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):

if self.reverse: # for arabic rec
text = self.pred_reverse(text)

result_list.append((text, np.mean(conf_list).tolist()))

if return_word_box:
word_list, word_col_list, state_list = self.get_word_info(text, selection)
result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list]))
else:
result_list.append((text, np.mean(conf_list).tolist()))
return result_list

def get_ignored_tokens(self):
@@ -111,14 +174,19 @@ def __init__(self, character_dict_path=None, use_space_char=False,
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)

def __call__(self, preds, label=None, *args, **kwargs):
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box)
if return_word_box:
for rec_idx, rec in enumerate(text):
wh_ratio = kwargs['wh_ratio_list'][rec_idx]
max_wh_ratio = kwargs['max_wh_ratio']
rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio)
if label is None:
return text
label = self.decode(label)
26 changes: 19 additions & 7 deletions ppstructure/predict_system.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.utility import parse_args, draw_structure_result, cal_ocr_word_box

logger = get_logger()

@@ -79,6 +79,8 @@ def __init__(self, args):
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
self.kie_predictor = SerRePredictor(args)

self.return_word_box = args.return_word_box

def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict = {
'image_orientation': 0,
@@ -156,17 +158,27 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
]
res = []
for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res
rec_str, rec_conf = rec_res[0], rec_res[1]
for token in style_token:
if token in rec_str:
rec_str = rec_str.replace(token, '')
if not self.recovery:
box += [x1, y1]
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
if self.return_word_box:
word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist(),
'text_word': word_box_content_list,
'text_word_region': word_box_list
})
else:
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
res_list.append({
'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
62 changes: 62 additions & 0 deletions ppstructure/utility.py
Original file line number Diff line number Diff line change
@@ -15,8 +15,13 @@
import ast
from PIL import Image, ImageDraw, ImageFont
import numpy as np
<<<<<<< HEAD
from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args

=======
from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
import math
>>>>>>> 1e11f254 (CV套件建设专项活动 - 文字识别返回单字识别坐标 (#10515))

def init_args():
parser = infer_args()
@@ -152,6 +157,63 @@ def draw_structure_result(image, result, font_path):
txts.append(text_result['text'])
scores.append(text_result['confidence'])

if 'text_word_region' in text_result:
for word_region in text_result['text_word_region']:
char_box = word_region
box_height = int(
math.sqrt((char_box[0][0] - char_box[3][0])**2 + (char_box[0][1] - char_box[3][1])**2))
box_width = int(
math.sqrt((char_box[0][0] - char_box[1][0])**2 + (char_box[0][1] - char_box[1][1])**2))
if box_height == 0 or box_width == 0:
continue
boxes.append(word_region)
txts.append("")
scores.append(1.0)

im_show = draw_ocr_box_txt(
img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
return im_show

def cal_ocr_word_box(rec_str, box, rec_word_info):
''' Calculate the detection frame for each word based on the results of recognition and detection of ocr'''

col_num, word_list, word_col_list, state_list = rec_word_info
box = box.tolist()
bbox_x_start = box[0][0]
bbox_x_end = box[1][0]
bbox_y_start = box[0][1]
bbox_y_end = box[2][1]

cell_width = (bbox_x_end - bbox_x_start)/col_num

word_box_list = []
word_box_content_list = []
cn_width_list = []
cn_col_list = []
for word, word_col, state in zip(word_list, word_col_list, state_list):
if state == 'cn':
if len(word_col) != 1:
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
char_width = char_seq_length/(len(word_col)-1)
cn_width_list.append(char_width)
cn_col_list += word_col
word_box_content_list += word
else:
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width)
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)
word_box_content_list.append("".join(word))
if len(cn_col_list) != 0:
if len(cn_width_list) != 0:
avg_char_width = np.mean(cn_width_list)
else:
avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str)
for center_idx in cn_col_list:
center_x = (center_idx+0.5)*cell_width
cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start
cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)

return word_box_content_list, word_box_list
10 changes: 8 additions & 2 deletions tools/infer/predict_rec.py
Original file line number Diff line number Diff line change
@@ -123,6 +123,7 @@ def __init__(self, args):
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.postprocess_params = postprocess_params
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
self.benchmark = args.benchmark
@@ -146,6 +147,7 @@ def __init__(self, args):
],
warmup=0,
logger=logger)
self.return_word_box = args.return_word_box

def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
@@ -415,11 +417,12 @@ def __call__(self, img_list):
valid_ratios = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
wh_ratio_list = []
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
wh_ratio_list.append(wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
@@ -624,7 +627,10 @@ def __call__(self, img_list):
preds = outputs
else:
preds = outputs[0]
rec_result = self.postprocess_op(preds)
if self.postprocess_params['name'] == 'CTCLabelDecode':
rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box, wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio)
else:
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
if self.benchmark:
2 changes: 1 addition & 1 deletion tools/infer/predict_system.py
Original file line number Diff line number Diff line change
@@ -101,7 +101,7 @@ def __call__(self, img, cls=True):
rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
text, score = rec_result[0], rec_result[1]
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
4 changes: 4 additions & 0 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
@@ -145,6 +145,10 @@ def init_args():

parser.add_argument("--show_log", type=str2bool, default=True)
parser.add_argument("--use_onnx", type=str2bool, default=False)

# extended function
parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery')

return parser