Skip to content

Commit

Permalink
Fix some perf regressions
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Feb 12, 2025
1 parent dd799b5 commit 9a9a61a
Showing 1 changed file with 43 additions and 56 deletions.
99 changes: 43 additions & 56 deletions marker/builders/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from ftfy import fix_text
from PIL import Image

from surya.detection import DetectionPredictor, InlineDetectionPredictor, TextDetectionResult
from surya.ocr_error import OCRErrorPredictor
Expand Down Expand Up @@ -53,15 +54,6 @@ class LineBuilder(BaseBuilder):
"The minimum coverage ratio required for the layout model to consider",
"the lines from the PdfProvider valid.",
] = .25
document_ocr_threshold: Annotated[
float,
"The minimum ratio of pages that must pass the layout coverage check",
"to avoid OCR.",
] = .8
min_ocr_line_pct: Annotated[
float,
"The minimum percentage of lines that need to be flagged as needing OCR per page for OCR to actually happen."
] = .9
detected_provider_line_overlap: Annotated[
float,
"The maximum overlap between a detected text line and a provider line to consider as a new line"
Expand Down Expand Up @@ -131,34 +123,63 @@ def get_ocr_error_batch_size(self):
return 4
return 4

def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_math_detection: bool):
page_images = [page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page in document.pages]
detection_results = self.detection_model(
images=page_images,
def get_detection_results(self, page_images: List[Image.Image], run_detection: List[bool], do_inline_math_detection: bool):
page_detection_results = self.detection_model(
images=[p for p, good in zip(page_images, run_detection) if good],
batch_size=self.get_detection_batch_size()
)
ocr_error_detection_results = self.ocr_error_detection(document.pages, provider.page_lines)
detection_results = []
idx = 0
for good in run_detection:
if good:
detection_results.append(page_detection_results[idx])
idx += 1
else:
detection_results.append(None)
assert idx == len(page_detection_results)

inline_detection_results = [None] * len(page_images)
if do_inline_math_detection:
inline_detection_results = self.inline_detection_model(
images=page_images,
text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in detection_results]
text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in detection_results],
batch_size=self.get_detection_batch_size()
)

return detection_results, inline_detection_results


def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_math_detection: bool):
ocr_error_detection_results = self.ocr_error_detection(document.pages, provider.page_lines)

boxes_to_ocr = {page.page_id: [] for page in document.pages}
page_lines = {page.page_id: [] for page in document.pages}

LineClass: Line = get_block_class(BlockTypes.Line)

for document_page, detection_result, inline_detection_result, ocr_error_detection_label in zip(
layout_good = []
for document_page, ocr_error_detection_label in zip(document.pages, ocr_error_detection_results.labels):
provider_lines: List[ProviderOutput] = provider.page_lines.get(document_page.page_id, [])
provider_lines_good = all([
bool(provider),
ocr_error_detection_label != 'bad',
self.check_layout_coverage(document_page, provider_lines)
])
layout_good.append(provider_lines_good)

run_detection = [not good or do_inline_math_detection for good in layout_good]
page_images = [page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page, good in zip(document.pages, run_detection) if good]
detection_results, inline_detection_results = self.get_detection_results(page_images, run_detection, do_inline_math_detection)

for document_page, detection_result, inline_detection_result, provider_lines_good in zip(
document.pages,
detection_results,
inline_detection_results,
ocr_error_detection_results.labels
layout_good
):
provider_lines: List[ProviderOutput] = provider.page_lines.get(document_page.page_id, [])
image_size = PolygonBox.from_bbox(detection_result.image_bbox).size
page_size = provider.get_page_bbox(document_page.page_id).size
image_size = PolygonBox.from_bbox(detection_result.image_bbox).size if detection_result else page_size

# Filter out detected equation blocks
inline_detection_result = self.filter_equation_overlaps(
Expand All @@ -182,19 +203,6 @@ def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_mat
merged_detection_boxes = self.determine_math_lines(text_result=detection_result, inline_result=inline_detection_result)
math_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if box.math]
nonmath_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if not box.math]
text_lines_without_matches = self.filter_detected_text_lines(
provider_lines,
[b for _,b in nonmath_detection_boxes],
image_size,
page_size
)

provider_lines_good = all([
bool(provider),
ocr_error_detection_label != 'bad',
self.check_layout_coverage(document_page, provider_lines),
(len(text_lines_without_matches) / len(nonmath_detection_boxes) < self.min_ocr_line_pct or len(text_lines_without_matches) < 2)
])

if provider_lines_good:
# Merge inline math blocks into the provider lines, only persist new detected text lines which do not overlap with existing provider lines
Expand Down Expand Up @@ -225,7 +233,7 @@ def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_mat
ocr_lines = {document_page.page_id: [] for document_page in document.pages}
for page_id, page_ocr_boxes in boxes_to_ocr.items():
page_size = provider.get_page_bbox(page_id).size
image_size = document.get_page(page_id).get_image(highres=False, remove_tables=not self.enable_table_ocr).size
image_size = document.get_page(page_id).get_image(highres=False).size
for box_to_ocr in page_ocr_boxes:
line_polygon = PolygonBox(polygon=box_to_ocr.polygon).rescale(image_size, page_size)
format = ["math"] if box_to_ocr.math else None
Expand Down Expand Up @@ -256,30 +264,6 @@ def ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: Provid
)
return ocr_error_detection_results

def filter_detected_text_lines(
self,
provider_lines: List[ProviderOutput],
detected_text_lines: List[TextBox],
image_size,
page_size
):
if len(provider_lines) == 0:
return detected_text_lines

if len(detected_text_lines) == 0:
return []

filtered_lines = []
rescaled_line_boxes = [PolygonBox(polygon=line.polygon).rescale(image_size, page_size).bbox for line in detected_text_lines]
provider_line_boxes = [line.line.polygon.bbox for line in provider_lines]
intersections = matrix_intersection_area(rescaled_line_boxes, provider_line_boxes)
for detected_line, intersection in zip(detected_text_lines, intersections):
max_intersection = np.max(intersection) / detected_line.area
if max_intersection < self.detected_provider_line_overlap:
filtered_lines.append(detected_line)

return filtered_lines

def check_layout_coverage(
self,
document_page: PageGroup,
Expand Down Expand Up @@ -369,6 +353,9 @@ def determine_math_lines(
Marks lines as math if they contain inline math boxes.
"""

if not text_result:
return []

text_boxes = [
TextBox(
polygon=box.polygon
Expand Down

0 comments on commit 9a9a61a

Please sign in to comment.