Skip to content

Commit

Permalink
Improve filtering
Browse files: browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Feb 12, 2025
1 parent ebad33a commit 27f3f7f
Showing 1 changed file with 40 additions and 4 deletions.
44 changes (40 additions & 4 deletions) in marker/builders/line.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Annotated, List, Optional, Tuple

import numpy as np
from ftfy import fix_text

from surya.detection import DetectionPredictor, InlineDetectionPredictor, TextDetectionResult
from surya.ocr_error import OCRErrorPredictor
Expand Down Expand Up @@ -57,6 +58,10 @@ class LineBuilder(BaseBuilder):
"The minimum ratio of pages that must pass the layout coverage check",
"to avoid OCR.",
] = .8
min_ocr_line_pct: Annotated[
float,
"The minimum percentage of lines that need to be OCRed per page for OCR to actually happen."
] = .1
detected_provider_line_overlap: Annotated[
float,
"The maximum overlap between a detected text line and a provider line to consider as a new line"
Expand All @@ -77,6 +82,10 @@ class LineBuilder(BaseBuilder):
float,
"The minimum area for an inline math block, in pixels."
] = 20
inline_math_line_vertical_merge_threshold: Annotated[
int,
"The maximum pixel distance between y1s for two lines to be merged"
] = 5
excluded_for_coverage: Annotated[
Tuple[BlockTypes],
"A list of block types to exclude from the layout coverage check.",
Expand Down Expand Up @@ -227,6 +236,10 @@ def filter_detected_text_lines(
max_intersection = np.max(intersection) / detected_line.area
if max_intersection < self.detected_provider_line_overlap:
filtered_lines.append(detected_line)

# If we have too few OCR lines, we should OCR none of them (assume provider is okay)
if len(filtered_lines) / len(detected_text_lines) < self.min_ocr_line_pct:
filtered_lines = []

return filtered_lines

Expand Down Expand Up @@ -401,11 +414,23 @@ def merge_provider_lines_inline_math(
if max_overlap <= self.line_inline_math_overlap_threshold:
continue

best_overlap = np.argmax(overlaps[i])
best_overlap_line = horizontal_provider_lines[best_overlap]
best_overlap_y1 = best_overlap_line[1].line.polygon.y_start

nonzero_idxs = np.nonzero(overlaps[i] > self.line_inline_math_overlap_threshold)[0]
for idx in nonzero_idxs:
provider_idx, provider_line = horizontal_provider_lines[idx]
line_overlaps = self.check_char_math_overlap(provider_line, math_line_polygon)
if line_overlaps:
provider_line_y1 = provider_line.line.polygon.y_start

remove_overlaps = False
if abs(provider_line_y1 - best_overlap_y1) > self.inline_math_line_vertical_merge_threshold:
remove_overlaps = True

line_overlaps = self.find_overlapping_math_chars(provider_line, math_line_polygon, remove_chars=remove_overlaps)

# Do not merge if too far above/below (but remove characters)
if line_overlaps and not remove_overlaps:
# Add the index of the provider line to the merge line
merge_line.append(provider_idx)

Expand Down Expand Up @@ -439,14 +464,15 @@ def merge_provider_lines_inline_math(
# Combine the spans of the provider line with the merged line
merged_line = merged_line.merge(provider_line)
self.add_math_span_format(merged_line)
already_merged.add(idx) # Prevent double merging
out_provider_lines.append((min_idx, merged_line))

# Sort to preserve original order
out_provider_lines = sorted(out_provider_lines, key=lambda x: x[0])
out_provider_lines = [p for _, p in out_provider_lines]
return out_provider_lines

def check_char_math_overlap(self, provider_line, math_line_polygon):
def find_overlapping_math_chars(self, provider_line, math_line_polygon, remove_chars=False):
# Identify if a character in the provider line overlaps with the inline math line - meaning that the line can be treated as math
spans = provider_line.spans
math_overlaps = False
Expand All @@ -461,8 +487,18 @@ def check_char_math_overlap(self, provider_line, math_line_polygon):
# For providers which surface characters - find line overlap based on characters
assert len(spans) == len(provider_line.chars), "Number of spans and characters in provider line do not match"
for span, span_chars in zip(spans, provider_line.chars):
new_span_chars = []
span_overlaps = False
for char in span_chars:
if char.polygon.intersection_pct(math_line_polygon) >= self.char_inline_math_overlap_threshold:
return True
span_overlaps = True
else:
new_span_chars.append(char)

# Remove stray characters that overlap with math lines
if span_overlaps and remove_chars:
span.text = fix_text(''.join(c.char for c in new_span_chars))

math_overlaps = math_overlaps or span_overlaps

return math_overlaps

0 comments on commit 27f3f7f

Please sign in to comment.