From 9a9a61a8908afc4202a057391f49173f922433c8 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Wed, 12 Feb 2025 16:47:35 -0500 Subject: [PATCH] Fix some perf regressions --- marker/builders/line.py | 99 ++++++++++++++++++----------------------- 1 file changed, 43 insertions(+), 56 deletions(-) diff --git a/marker/builders/line.py b/marker/builders/line.py index cd318673..a7388873 100644 --- a/marker/builders/line.py +++ b/marker/builders/line.py @@ -3,6 +3,7 @@ import numpy as np from ftfy import fix_text +from PIL import Image from surya.detection import DetectionPredictor, InlineDetectionPredictor, TextDetectionResult from surya.ocr_error import OCRErrorPredictor @@ -53,15 +54,6 @@ class LineBuilder(BaseBuilder): "The minimum coverage ratio required for the layout model to consider", "the lines from the PdfProvider valid.", ] = .25 - document_ocr_threshold: Annotated[ - float, - "The minimum ratio of pages that must pass the layout coverage check", - "to avoid OCR.", - ] = .8 - min_ocr_line_pct: Annotated[ - float, - "The minimum percentage of lines that need to be flagged as needing OCR per page for OCR to actually happen." 
- ] = .9 detected_provider_line_overlap: Annotated[ float, "The maximum overlap between a detected text line and a provider line to consider as a new line" @@ -131,34 +123,63 @@ def get_ocr_error_batch_size(self): return 4 return 4 - def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_math_detection: bool): - page_images = [page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page in document.pages] - detection_results = self.detection_model( - images=page_images, + def get_detection_results(self, page_images: List[Image.Image], run_detection: List[bool], do_inline_math_detection: bool): + page_detection_results = self.detection_model( + images=[p for p, good in zip(page_images, run_detection) if good], + batch_size=self.get_detection_batch_size() ) - ocr_error_detection_results = self.ocr_error_detection(document.pages, provider.page_lines) + detection_results = [] + idx = 0 + for good in run_detection: + if good: + detection_results.append(page_detection_results[idx]) + idx += 1 + else: + detection_results.append(None) + assert idx == len(page_detection_results) inline_detection_results = [None] * len(page_images) if do_inline_math_detection: inline_detection_results = self.inline_detection_model( images=page_images, - text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in detection_results] + text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in detection_results], + batch_size=self.get_detection_batch_size() ) + return detection_results, inline_detection_results + + + def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_math_detection: bool): + ocr_error_detection_results = self.ocr_error_detection(document.pages, provider.page_lines) + boxes_to_ocr = {page.page_id: [] for page in document.pages} page_lines = {page.page_id: [] for page in document.pages} LineClass: Line = get_block_class(BlockTypes.Line) - for document_page, detection_result, inline_detection_result, 
ocr_error_detection_label in zip( + layout_good = [] + for document_page, ocr_error_detection_label in zip(document.pages, ocr_error_detection_results.labels): + provider_lines: List[ProviderOutput] = provider.page_lines.get(document_page.page_id, []) + provider_lines_good = all([ + bool(provider), + ocr_error_detection_label != 'bad', + self.check_layout_coverage(document_page, provider_lines) + ]) + layout_good.append(provider_lines_good) + + run_detection = [not good or do_inline_math_detection for good in layout_good] + page_images = [page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page, good in zip(document.pages, run_detection) if good] + detection_results, inline_detection_results = self.get_detection_results(page_images, run_detection, do_inline_math_detection) + + for document_page, detection_result, inline_detection_result, provider_lines_good in zip( document.pages, detection_results, inline_detection_results, - ocr_error_detection_results.labels + layout_good ): provider_lines: List[ProviderOutput] = provider.page_lines.get(document_page.page_id, []) - image_size = PolygonBox.from_bbox(detection_result.image_bbox).size page_size = provider.get_page_bbox(document_page.page_id).size + image_size = PolygonBox.from_bbox(detection_result.image_bbox).size if detection_result else page_size # Filter out detected equation blocks inline_detection_result = self.filter_equation_overlaps( @@ -182,19 +203,6 @@ def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_mat merged_detection_boxes = self.determine_math_lines(text_result=detection_result, inline_result=inline_detection_result) math_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if box.math] nonmath_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if not box.math] - text_lines_without_matches = self.filter_detected_text_lines( - provider_lines, - [b for _,b in nonmath_detection_boxes], - image_size, - 
page_size - ) - - provider_lines_good = all([ - bool(provider), - ocr_error_detection_label != 'bad', - self.check_layout_coverage(document_page, provider_lines), - (len(text_lines_without_matches) / len(nonmath_detection_boxes) < self.min_ocr_line_pct or len(text_lines_without_matches) < 2) - ]) if provider_lines_good: # Merge inline math blocks into the provider lines, only persist new detected text lines which do not overlap with existing provider lines @@ -225,7 +233,7 @@ def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_mat ocr_lines = {document_page.page_id: [] for document_page in document.pages} for page_id, page_ocr_boxes in boxes_to_ocr.items(): page_size = provider.get_page_bbox(page_id).size - image_size = document.get_page(page_id).get_image(highres=False, remove_tables=not self.enable_table_ocr).size + image_size = document.get_page(page_id).get_image(highres=False).size for box_to_ocr in page_ocr_boxes: line_polygon = PolygonBox(polygon=box_to_ocr.polygon).rescale(image_size, page_size) format = ["math"] if box_to_ocr.math else None @@ -256,30 +264,6 @@ def ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: Provid ) return ocr_error_detection_results - def filter_detected_text_lines( - self, - provider_lines: List[ProviderOutput], - detected_text_lines: List[TextBox], - image_size, - page_size - ): - if len(provider_lines) == 0: - return detected_text_lines - - if len(detected_text_lines) == 0: - return [] - - filtered_lines = [] - rescaled_line_boxes = [PolygonBox(polygon=line.polygon).rescale(image_size, page_size).bbox for line in detected_text_lines] - provider_line_boxes = [line.line.polygon.bbox for line in provider_lines] - intersections = matrix_intersection_area(rescaled_line_boxes, provider_line_boxes) - for detected_line, intersection in zip(detected_text_lines, intersections): - max_intersection = np.max(intersection) / detected_line.area - if max_intersection < 
self.detected_provider_line_overlap: - filtered_lines.append(detected_line) - - return filtered_lines - def check_layout_coverage( self, document_page: PageGroup, @@ -369,6 +353,9 @@ def determine_math_lines( Marks lines as math if they contain inline math boxes. """ + if not text_result: + return [] + text_boxes = [ TextBox( polygon=box.polygon