Skip to content

Commit

Permalink
Fix list issue, general cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Feb 13, 2025
1 parent c180554 commit eabebb2
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 85 deletions.
28 changes: 16 additions & 12 deletions benchmarks/table/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ def extract_tables(children: List[JSONBlockOutput]):
tables.extend(extract_tables(child.children))
return tables

def fix_table_html(table_html: str) -> str:
    """Normalize table HTML so it can be compared fairly with Fintabnet data.

    The following rewrites are applied, in order:

    * the first ``<tbody>`` wrapper is removed (its children are kept),
    * every ``<th>`` cell is renamed to ``<td>``,
    * every ``<br>`` tag is replaced by an empty string,
    * newlines in the serialized markup are replaced by single spaces.

    Args:
        table_html: Raw HTML markup of a single table.

    Returns:
        The normalized HTML as a string.
    """
    soup = BeautifulSoup(table_html, 'html.parser')

    # Fintabnet tables are not wrapped in <tbody>, so strip the wrapper
    # while keeping the rows it contains.
    body = soup.find('tbody')
    if body is not None:
        body.unwrap()

    # Fintabnet doesn't use <th>; treat header cells as ordinary cells.
    for header_cell in soup.find_all('th'):
        header_cell.name = 'td'

    # Line breaks inside cells are dropped entirely.
    for line_break in soup.find_all('br'):
        line_break.replace_with(soup.new_string(''))

    # Fintabnet uses spaces instead of newlines
    return str(soup).replace("\n", " ")


def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
models = create_model_dict()
Expand Down Expand Up @@ -154,18 +168,8 @@ def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, m

# marker wraps the table in <tbody> which fintabnet data doesn't
# Fintabnet doesn't use th tags, need to be replaced for fair comparison
marker_table_soup = BeautifulSoup(marker_table.html, 'html.parser')
tbody = marker_table_soup.find('tbody')
if tbody:
tbody.unwrap()
for th_tag in marker_table_soup.find_all('th'):
th_tag.name = 'td'
for br_tag in marker_table_soup.find_all('br'):
br_tag.replace_with(marker_table_soup.new_string(''))

marker_table_html = str(marker_table_soup)
marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
gemini_table_html = gemini_table.replace("\n", " ") # Fintabnet uses spaces instead of newlines
marker_table_html = fix_table_html(marker_table.html)
gemini_table_html = fix_table_html(gemini_table)

results.append({
"marker_table": marker_table_html,
Expand Down
64 changes: 10 additions & 54 deletions marker/builders/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
from ftfy import fix_text
from PIL import Image
from PIL import Image, ImageDraw

from surya.detection import DetectionPredictor, InlineDetectionPredictor, TextDetectionResult
from surya.ocr_error import OCRErrorPredictor
Expand Down Expand Up @@ -77,7 +77,7 @@ class LineBuilder(BaseBuilder):
inline_math_line_vertical_merge_threshold: Annotated[
int,
"The maximum pixel distance between y1s for two lines to be merged"
] = 5
] = 8
excluded_for_coverage: Annotated[
Tuple[BlockTypes],
"A list of block types to exclude from the layout coverage check.",
Expand All @@ -90,6 +90,7 @@ class LineBuilder(BaseBuilder):
bool,
"Whether to run texify on inline math spans."
] = False
ocr_remove_blocks: Tuple[BlockTypes, ...] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents, BlockTypes.Equation)

def __init__(self, detection_model: DetectionPredictor, inline_detection_model: InlineDetectionPredictor, ocr_error_model: OCRErrorPredictor, config=None):
super().__init__(config)
Expand Down Expand Up @@ -168,7 +169,7 @@ def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_mat
layout_good.append(provider_lines_good)

run_detection = [not good or do_inline_math_detection for good in layout_good]
page_images = [page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page, good in zip(document.pages, run_detection) if good]
page_images = [page.get_image(highres=False, remove_blocks=self.ocr_remove_blocks) for page, good in zip(document.pages, run_detection) if good]
detection_results, inline_detection_results = self.get_detection_results(page_images, run_detection, do_inline_math_detection)

assert len(detection_results) == len(inline_detection_results) == len(layout_good) == len(document.pages)
Expand All @@ -182,24 +183,6 @@ def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_mat
page_size = provider.get_page_bbox(document_page.page_id).size
image_size = PolygonBox.from_bbox(detection_result.image_bbox).size if detection_result else page_size

# Filter out detected equation blocks
inline_detection_result = self.filter_equation_overlaps(
document,
document_page,
inline_detection_result,
image_size,
page_size,
self.line_inline_math_overlap_threshold
)
detection_result = self.filter_equation_overlaps(
document,
document_page,
detection_result,
image_size,
page_size,
self.line_text_overlap_threshold
)

# Merge text and inline math detection results
merged_detection_boxes = self.determine_math_lines(text_result=detection_result, inline_result=inline_detection_result)
math_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if box.math]
Expand Down Expand Up @@ -317,38 +300,11 @@ def merge_blocks(self, document: Document, page_provider_lines: ProviderPageLine
# Text extraction method is overridden later for OCRed documents
document_page.merge_blocks(merged_lines, text_extraction_method='pdftext')

def filter_equation_overlaps(
    self,
    document,
    page: PageGroup,
    inline_boxes: TextDetectionResult,
    image_size,
    page_size,
    threshold: float
):
    """Drop detected boxes that are mostly covered by an equation block.

    Each detected box is rescaled with ``rescale(image_size, page_size)`` and
    intersected with the bounding boxes of the page's ``Equation`` blocks.
    A box is removed when its largest intersection with any single equation,
    divided by the box's own area, is greater than ``threshold``.

    Args:
        document: Document used to resolve the page's contained blocks.
        page: Page whose equation blocks are checked against.
        inline_boxes: Detection result to filter; may be ``None``.
        image_size: Size of the image the detections were produced on.
        page_size: Size of the page the equation blocks are expressed in.
        threshold: Overlap fraction above which a box is discarded.

    Returns:
        ``inline_boxes`` itself (its ``bboxes`` list is reassigned in place),
        or ``None`` when ``None`` was passed in.
    """
    if inline_boxes is None:
        return inline_boxes

    equation_bboxes = [
        equation.polygon.bbox
        for equation in page.contained_blocks(document, (BlockTypes.Equation,))
    ]

    # Collect the rescaled bbox and area of every detected box in one pass.
    detected_bboxes = []
    detected_areas = []
    for detected in inline_boxes.bboxes:
        rescaled = PolygonBox(polygon=detected.polygon).rescale(image_size, page_size)
        detected_bboxes.append(rescaled.bbox)
        detected_areas.append(rescaled.area)

    # Nothing to compare against: leave the detections untouched.
    if not equation_bboxes or not detected_bboxes:
        return inline_boxes

    intersections = matrix_intersection_area(detected_bboxes, equation_bboxes)
    # Fraction of each detected box covered by its best-matching equation.
    covered_fraction = np.max(intersections, axis=-1) / np.array(detected_areas)
    should_drop = covered_fraction > threshold

    inline_boxes.bboxes = [
        detected
        for detected, drop in zip(inline_boxes.bboxes, should_drop)
        if not drop
    ]
    return inline_boxes


def determine_math_lines(
self,
text_result: TextDetectionResult,
inline_result: TextDetectionResult,
math_box_padding: float = .05
) -> List[TextBox]:
"""
Marks lines as math if they contain inline math boxes.
Expand Down Expand Up @@ -391,7 +347,7 @@ def determine_math_lines(
continue

# Ignore vertical lines
if max_overlap_box.height > max_overlap_box.width:
if max_overlap_box.height > max_overlap_box.width * 2:
continue

max_overlap_box.math = True
Expand All @@ -406,11 +362,11 @@ def add_math_span_format(self, provider_line):
provider_line.line.formats.append("math")

def merge_provider_lines_inline_math(
self,
provider_lines: List[ProviderOutput],
inline_math_lines: List[TextBox],
image_size,
page_size
self,
provider_lines: List[ProviderOutput],
inline_math_lines: List[TextBox],
image_size,
page_size
):
# When provider lines is empty or no inline math detected, return provider lines
if not provider_lines or not inline_math_lines:
Expand Down
5 changes: 2 additions & 3 deletions marker/builders/llm_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Annotated

from surya.layout import LayoutPredictor
from surya.ocr_error import OCRErrorPredictor
from tqdm import tqdm
from pydantic import BaseModel

Expand Down Expand Up @@ -98,8 +97,8 @@ class LLMLayoutBuilder(LayoutBuilder):
Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
"""

def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None):
super().__init__(layout_model, ocr_error_model, config)
def __init__(self, layout_model: LayoutPredictor, config=None):
super().__init__(layout_model, config)

self.model = GoogleModel(self.google_api_key, self.model_name)

Expand Down
12 changes: 9 additions & 3 deletions marker/processors/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Annotated, TypedDict, List
from typing import Annotated, TypedDict, List, Sequence

from pydantic import BaseModel
from tqdm import tqdm
from PIL import Image

from marker.processors import BaseProcessor
from marker.schema import BlockTypes
from marker.services.google import GoogleModel
from marker.schema.blocks import Block
from marker.schema.document import Document
Expand Down Expand Up @@ -75,8 +76,13 @@ def __init__(self, config=None):

self.model = GoogleModel(self.google_api_key, self.model_name)

def extract_image(self, document: Document, image_block: Block):
return image_block.get_image(document, highres=True, expansion=(self.image_expansion_ratio, self.image_expansion_ratio))
def extract_image(self, document: Document, image_block: Block, remove_blocks: Sequence[BlockTypes] | None = None) -> Image.Image:
    """Return the high-resolution image for ``image_block``.

    The same ``self.image_expansion_ratio`` is passed for both components of
    the ``expansion`` tuple, and ``remove_blocks`` is forwarded unchanged to
    ``image_block.get_image``.

    Args:
        document: Document the block belongs to.
        image_block: Block whose image should be extracted.
        remove_blocks: Optional block types forwarded to ``get_image``.

    Returns:
        Whatever ``image_block.get_image`` returns for these arguments.
    """
    ratio = self.image_expansion_ratio
    expansion = (ratio, ratio)
    return image_block.get_image(
        document,
        highres=True,
        expansion=expansion,
        remove_blocks=remove_blocks,
    )


class BaseLLMComplexBlockProcessor(BaseLLMProcessor):
Expand Down
14 changes: 9 additions & 5 deletions marker/processors/llm/llm_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class LLMTextProcessor(BaseLLMSimpleBlockProcessor):
] = 10

block_types = (BlockTypes.Line,)
image_remove_blocks = (BlockTypes.Equation,)
text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
You will receive an image of a text block and a set of extracted lines corresponding to the text in the image.
Your task is to correct any errors in the extracted lines, including math, formatting, and other inaccuracies, and output the corrected lines in a JSON format.
Expand All @@ -35,8 +36,8 @@ class LLMTextProcessor(BaseLLMSimpleBlockProcessor):
* Formatting: Maintain consistent formatting with the text block image, including spacing, indentation, and special characters.
* Other inaccuracies: If the image is handwritten then you may correct any spelling errors, or other discrepancies.
5. Do not remove any formatting i.e bold, italics, math, superscripts, subscripts, etc from the extracted lines unless it is necessary to correct an error.
6. Ensure that inline math is properly with inline math tags.
7. The number of corrected lines in the output MUST equal the number of extracted lines provided in the input. Do not add or remove lines.
6. Ensure that inline math is properly formatted with inline math tags.
7. The number of corrected lines in the output MUST equal the number of extracted lines provided in the input. Do not add or remove lines. There are exactly {line_count} input lines.
8. Output the corrected lines in JSON format with a "lines" field, as shown in the example below.
9. You absolutely cannot remove any <a href='#...'>...</a> tags, those are extremely important for references and are coming directly from the document, you MUST always preserve them.
Expand Down Expand Up @@ -119,9 +120,12 @@ def block_prompts(self, document: Document) -> List[PromptData]:
pages = [b["page"] for b in block_data]
block_lines = [block.formatted_text(document) for block in blocks]

prompt = self.text_math_rewriting_prompt.replace("{extracted_lines}",
json.dumps({"extracted_lines": block_lines}, indent=2))
images = [self.extract_image(document, block) for block in blocks]
prompt = (
self.text_math_rewriting_prompt
.replace("{extracted_lines}",json.dumps({"extracted_lines": block_lines}, indent=2))
.replace("{line_count}", str(len(block_lines)))
)
images = [self.extract_image(document, block, remove_blocks=self.image_remove_blocks) for block in blocks]
image = self.combine_images(images)

prompt_data.append({
Expand Down
4 changes: 2 additions & 2 deletions marker/schema/blocks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ def from_block(cls, block: Block) -> Block:
block_attrs = block.model_dump(exclude=["id", "block_id", "block_type"])
return cls(**block_attrs)

def get_image(self, document: Document, highres: bool = False, expansion: Tuple[float, float] | None = None) -> Image.Image | None:
def get_image(self, document: Document, highres: bool = False, expansion: Tuple[float, float] | None = None, remove_blocks: Sequence[BlockTypes] | None = None) -> Image.Image | None:
image = self.highres_image if highres else self.lowres_image
if image is None:
page = document.get_page(self.page_id)
page_image = page.highres_image if highres else page.lowres_image
page_image = page.get_image(highres=highres, remove_blocks=remove_blocks)

# Scale to the image size
bbox = self.polygon.rescale((page.polygon.width, page.polygon.height), page_image.size)
Expand Down
12 changes: 6 additions & 6 deletions marker/schema/groups/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ def add_child(self, block: Block):
else:
self.children.append(block)

def get_image(self, *args, highres: bool = False, remove_tables: bool = False, **kwargs):
def get_image(self, *args, highres: bool = False, remove_blocks: Sequence[BlockTypes] | None = None, **kwargs):
image = self.highres_image if highres else self.lowres_image

# Avoid double OCR for tables
if remove_tables:
# Avoid double OCR for certain elements
if remove_blocks:
image = image.copy()
draw = ImageDraw.Draw(image)
table_blocks = [block for block in self.children if block.block_type in (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents)]
for table_block in table_blocks:
poly = table_block.polygon.rescale(self.polygon.size, image.size).polygon
bad_blocks = [block for block in self.children if block.block_type in remove_blocks]
for bad_block in bad_blocks:
poly = bad_block.polygon.rescale(self.polygon.size, image.size).polygon
poly = [(int(p[0]), int(p[1])) for p in poly]
draw.polygon(poly, fill='white')

Expand Down

0 comments on commit eabebb2

Please sign in to comment.