diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index 69f8e27765..152e989c3b 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -33,6 +33,7 @@ from ppocr.utils.logging import get_logger from ppocr.utils.visual import draw_ser_results, draw_re_results from tools.infer.predict_system import TextSystem +from tools.infer.predict_rec import TextRecognizer from ppstructure.layout.predict_layout import LayoutPredictor from ppstructure.table.predict_table import TableSystem, to_excel from ppstructure.utility import parse_args, draw_structure_result, cal_ocr_word_box @@ -65,6 +66,7 @@ def __init__(self, args): self.layout_predictor = None self.text_system = None self.table_system = None + self.formula_system = None if args.layout: self.layout_predictor = LayoutPredictor(args) if args.ocr: @@ -78,6 +80,13 @@ def __init__(self, args): ) else: self.table_system = TableSystem(args) + if args.formula: + args_formula = deepcopy(args) + args_formula.rec_algorithm = args.formula_algorithm + args_formula.rec_model_dir = args.formula_model_dir + args_formula.rec_char_dict_path = args.formula_char_dict_path + args_formula.rec_batch_num = args.formula_batch_num + self.formula_system = TextRecognizer(args_formula) elif self.mode == "kie": from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor @@ -92,6 +101,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0): "layout": 0, "table": 0, "table_match": 0, + "formula": 0, "det": 0, "rec": 0, "kie": 0, @@ -157,6 +167,12 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0): time_dict["table_match"] += table_time_dict["match"] time_dict["det"] += table_time_dict["det"] time_dict["rec"] += table_time_dict["rec"] + + elif region["label"] == "equation" and self.formula_system is not None: + latex_res, formula_time = self.formula_system([roi_img]) + time_dict["formula"] += formula_time + res = {"latex": latex_res[0]} + else: if text_res is not None: # Filter the text results whose regions intersect with the current layout bbox. @@ -357,6 +373,9 @@ def main(args): sorted_layout_boxes, convert_info_docx, ) + from ppstructure.recovery.recovery_to_markdown import ( + convert_info_markdown, + ) h, w, _ = img.shape res = sorted_layout_boxes(res, w) @@ -365,6 +384,8 @@ def main(args): if args.recovery and all_res != []: try: convert_info_docx(img, all_res, save_folder, img_name) + if args.recovery_to_markdown: + convert_info_markdown(all_res, save_folder, img_name) except Exception as ex: logger.error( "error in layout recovery image:{}, err msg: {}".format( diff --git a/ppstructure/recovery/recovery_to_doc.py b/ppstructure/recovery/recovery_to_doc.py index edbeefd9f2..1974a09d1f 100644 --- a/ppstructure/recovery/recovery_to_doc.py +++ b/ppstructure/recovery/recovery_to_doc.py @@ -67,6 +67,8 @@ def convert_info_docx(img, res, save_folder, img_name): parser = HtmlToDocx() parser.table_style = "TableGrid" parser.handle_table(region["res"]["html"], doc) + elif region["type"] == "equation" and "latex" in region["res"]: + pass else: paragraph = doc.add_paragraph() paragraph_format = paragraph.paragraph_format diff --git a/ppstructure/recovery/recovery_to_markdown.py b/ppstructure/recovery/recovery_to_markdown.py new file mode 100644 index 0000000000..833628e134 --- /dev/null +++ b/ppstructure/recovery/recovery_to_markdown.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +from ppocr.utils.logging import get_logger + +logger = get_logger() + + +def check_merge_method(in_region): + """Select the function to merge paragraph. + + Determine the paragraph merging method based on the positional + relationship between the text bbox and the first line of text in the text bbox. + + Args: + in_region: Elements with text type in the layout result. + + Returns: + Merge the functions of paragraph, convert_text_space_head or convert_text_space_tail. + """ + text_bbox = in_region["bbox"] + text_x1 = text_bbox[0] + frist_line_box = in_region["res"][0]["text_region"] + point_1 = frist_line_box[0] + point_2 = frist_line_box[2] + frist_line_x1 = point_1[0] + frist_line_height = abs(point_2[1] - point_1[1]) + x1_distance = frist_line_x1 - text_x1 + return ( + convert_text_space_head + if x1_distance > frist_line_height + else convert_text_space_tail + ) + + +def convert_text_space_head(in_region): + """The function to merge paragraph. + + The sign of dividing paragraph is that there are two spaces at the beginning. + + Args: + in_region: Elements with text type in the layout result. + + Returns: + The text content of the current text box. + """ + text = "" + pre_x = None + frist_line = True + for i, res in enumerate(in_region["res"]): + point1 = res["text_region"][0] + point2 = res["text_region"][2] + h = point2[1] - point1[1] + + if i == 0: + text += res["text"] + pre_x = point1[0] + continue + + x1 = point1[0] + if frist_line: + if abs(pre_x - x1) < h: + text += "\n\n" + text += res["text"] + frist_line = True + else: + text += res["text"] + frist_line = False + else: + same_paragh = abs(pre_x - x1) < h + if same_paragh: + text += res["text"] + frist_line = False + else: + text += "\n\n" + text += res["text"] + frist_line = True + pre_x = x1 + return text + + +def convert_text_space_tail(in_region): + """The function to merge paragraph. + + The symbol for dividing paragraph is a space at the end. + + Args: + in_region: Elements with text type in the layout result. + + Returns: + The text content of the current text box. + """ + text = "" + frist_line = True + text_bbox = in_region["bbox"] + width = text_bbox[2] - text_bbox[0] + for i, res in enumerate(in_region["res"]): + point1 = res["text_region"][0] + point2 = res["text_region"][2] + row_width = point2[0] - point1[0] + row_height = point2[1] - point1[1] + full_row_threshold = width - row_height + is_full = row_width >= full_row_threshold + + if frist_line: + text += "\n\n" + text += res["text"] + else: + text += res["text"] + + frist_line = not is_full + return text + + +def convert_info_markdown(res, save_folder, img_name): + """Save the recognition result as a markdown file. + + Args: + res: Recognition result + save_folder: Folder to save the markdown file + img_name: PDF file or image file name + + Returns: + None + """ + + def replace_special_char(content): + special_chars = ["*", "`", "~", "$"] + for char in special_chars: + content = content.replace(char, "\\" + char) + return content + + markdown_string = [] + + for i, region in enumerate(res): + if len(region["res"]) == 0: + continue + img_idx = region["img_idx"] + + if region["type"].lower() == "figure": + img_file_name = "{}_{}.jpg".format(region["bbox"], img_idx) + markdown_string.append( + f"""
\n\t\n
""" + ) + elif region["type"].lower() == "title": + markdown_string.append(f"""# {region["res"][0]["text"]}""") + elif region["type"].lower() == "table": + markdown_string.append(region["res"]["html"]) + elif region["type"].lower() == "header" or region["type"].lower() == "footer": + pass + elif region["type"].lower() == "equation" and "latex" in region["res"]: + markdown_string.append(f"""$${region["res"]["latex"]}$$""") + elif region["type"].lower() == "text": + merge_func = check_merge_method(region) + # logger.warning(f"use merge method:{merge_func.__name__}") + markdown_string.append(replace_special_char(merge_func(region))) + else: + string = "" + for line in region["res"]: + string += line["text"] + " " + markdown_string.append(string) + + md_path = os.path.join(save_folder, "{}_ocr.md".format(img_name)) + markdown_string = "\n\n".join(markdown_string) + markdown_string = re.sub(r"\n{3,}", "\n\n", markdown_string) + with open(md_path, "w", encoding="utf-8") as f: + f.write(markdown_string) + logger.info("markdown save to {}".format(md_path)) diff --git a/ppstructure/utility.py b/ppstructure/utility.py index bffc1fdda0..97e434d6f8 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -40,6 +40,15 @@ def init_args(): type=str, default="../ppocr/utils/dict/table_structure_dict_ch.txt", ) + # params for formula recognition + parser.add_argument("--formula_algorithm", type=str, default="LaTeXOCR") + parser.add_argument("--formula_model_dir", type=str) + parser.add_argument( + "--formula_char_dict_path", + type=str, + default="../ppocr/utils/dict/latex_ocr_tokenizer.json", + ) + parser.add_argument("--formula_batch_num", type=int, default=1) # params for layout parser.add_argument("--layout_model_dir", type=str) parser.add_argument( @@ -89,6 +98,12 @@ def init_args(): default=True, help="In the forward, whether the table area uses table recognition", ) + parser.add_argument( + "--formula", + type=str2bool, + default=False, + help="Whether to enable formula recognition", + ) parser.add_argument( "--ocr", type=str2bool, @@ -102,6 +117,12 @@ def init_args(): default=False, help="Whether to enable layout of recovery", ) + parser.add_argument( + "--recovery_to_markdown", + type=str2bool, + default=False, + help="Whether to enable layout of recovery to markdown", + ) parser.add_argument( "--use_pdf2docx_api", type=str2bool, @@ -182,7 +203,9 @@ def draw_structure_result(image, result, font_path): (box_layout[0], box_layout[1]), region["type"], fill=text_color, font=font ) - if region["type"] == "table": + if region["type"] == "table" or ( + region["type"] == "equation" and "latex" in region["res"] + ): pass else: for text_result in region["res"]: