Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Feb 13, 2025
1 parent ab910fa commit 42d8ab9
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 38 deletions.
29 changes: 15 additions & 14 deletions marker/builders/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ class LineBuilder(BaseBuilder):
"The minimum coverage ratio required for the layout model to consider",
"the lines from the PdfProvider valid.",
] = .25
detected_provider_line_overlap: Annotated[
float,
"The maximum overlap between a detected text line and a provider line to consider as a new line"
] = .3
span_inline_math_overlap_threshold: Annotated[
float,
"The minimum overlap of a span with an inline math box to consider for removal"
Expand Down Expand Up @@ -124,29 +120,33 @@ def get_ocr_error_batch_size(self):
return 4

def get_detection_results(self, page_images: List[Image.Image], run_detection: List[bool], do_inline_math_detection: bool):
page_images = [p for p, good in zip(page_images, run_detection) if good]
page_detection_results = self.detection_model(
images=[p for p, good in zip(page_images, run_detection) if good],
images=page_images,
batch_size=self.get_detection_batch_size()
)
inline_detection_results = [None] * len(page_detection_results)
if do_inline_math_detection:
inline_detection_results = self.inline_detection_model(
images=page_images,
text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in page_detection_results],
batch_size=self.get_detection_batch_size()
)

detection_results = []
inline_results = []
idx = 0
for good in run_detection:
if good:
detection_results.append(page_detection_results[idx])
inline_results.append(inline_detection_results[idx])
idx += 1
else:
detection_results.append(None)
inline_results.append(None)
assert idx == len(page_images)

inline_detection_results = [None] * len(run_detection)
if do_inline_math_detection:
inline_detection_results = self.inline_detection_model(
images=page_images,
text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in detection_results],
batch_size=self.get_detection_batch_size()
)

return detection_results, inline_detection_results
return detection_results, inline_results


def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_math_detection: bool):
Expand All @@ -171,6 +171,7 @@ def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_mat
page_images = [page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page, good in zip(document.pages, run_detection) if good]
detection_results, inline_detection_results = self.get_detection_results(page_images, run_detection, do_inline_math_detection)

assert len(detection_results) == len(inline_detection_results) == len(layout_good) == len(document.pages)
for document_page, detection_result, inline_detection_result, provider_lines_good in zip(
document.pages,
detection_results,
Expand Down
3 changes: 2 additions & 1 deletion marker/converters/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def build_document(self, filepath: str):
ocr_builder = self.resolve_dependencies(OcrBuilder)
with provider_cls(filepath, self.config) as provider:
document = DocumentBuilder(self.config)(provider, layout_builder, line_builder, ocr_builder)
StructureBuilder(self.config)(document)
structure_builder_cls = self.resolve_dependencies(StructureBuilder)
structure_builder_cls(document)

for processor in self.processor_list:
processor(document)
Expand Down
1 change: 1 addition & 0 deletions marker/processors/llm/llm_handwriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def inference_blocks(self, document: Document) -> List[BlockData]:
for block_data in blocks:
raw_text = block_data["block"].raw_text(document)
block = block_data["block"]

# Don't process text blocks that contain lines already
if block.block_type == BlockTypes.Text:
lines = block.contained_blocks(document, (BlockTypes.Line,))
Expand Down
1 change: 0 additions & 1 deletion tests/builders/test_blank_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,5 @@ def test_blank_page(config, pdf_provider, layout_model, ocr_error_model, recogni
layout_builder.add_blocks_to_pages(document.pages, layout_results)
line_builder.merge_blocks(document, provider_lines, ocr_lines)


assert all([isinstance(p.children, list) for p in document.pages])
assert all([isinstance(p.structure, list) for p in document.pages])
2 changes: 1 addition & 1 deletion tests/builders/test_garbled_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_garbled_pdf(pdf_document, detection_model, recognition_model, table_rec

table_cell = pdf_document.pages[0].get_block(table_block.structure[0])
assert table_cell.block_type == BlockTypes.Line
assert table_cell.structure is None
assert table_cell.structure[0] == "/page/0/Span/2"

# We don't OCR in the initial pass, only with the TableProcessor
processor = TableProcessor(detection_model, recognition_model, table_rec_model)
Expand Down
48 changes: 48 additions & 0 deletions tests/processors/test_inline_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest.mock import Mock

import pytest

from marker.processors.llm.llm_meta import LLMSimpleBlockMetaProcessor
from marker.processors.llm.llm_text import LLMTextProcessor
from marker.schema import BlockTypes


@pytest.mark.filename("adversarial.pdf")
@pytest.mark.config({"page_range": [0], "use_llm": True})
def test_llm_text_processor(pdf_document, mocker):
    """Run the LLM text processor against a mocked model and check the first math line.

    The document is built with ``use_llm`` enabled; the test expects exactly
    three lines whose formats include "math", feeds a canned correction for
    each of them through the mocked model, and then inspects the first span
    of the first such line.
    """
    # Keep only the lines whose formats include "math".
    math_lines = [
        line
        for line in pdf_document.contained_blocks((BlockTypes.Line,))
        if line.formats and "math" in line.formats
    ]
    assert len(math_lines) == 3

    # The patched model class hands back one canned correction per math line.
    fake_model_cls = Mock()
    fake_model_cls.return_value.generate_response.return_value = {
        "corrected_lines": ["<math>Text</math>"] * len(math_lines)
    }
    mocker.patch("marker.processors.llm.GoogleModel", fake_model_cls)

    llm_config = {"use_llm": True, "google_api_key": "test"}
    meta_processor = LLMSimpleBlockMetaProcessor([LLMTextProcessor(llm_config)], llm_config)
    meta_processor(pdf_document)

    first_span = math_lines[0].contained_blocks(pdf_document, (BlockTypes.Span,))[0]
    assert first_span.text == "Text\n"  # Newline inserted at end of line
    assert first_span.formats == ["math"]


@pytest.mark.filename("adversarial.pdf")
@pytest.mark.config({"page_range": [0]})
def test_llm_text_processor_disabled(pdf_document):
    """With only ``page_range`` configured, no line should carry the "math" format."""
    # Keep only the lines whose formats include "math".
    math_lines = [
        line
        for line in pdf_document.contained_blocks((BlockTypes.Line,))
        if line.formats and "math" in line.formats
    ]
    assert len(math_lines) == 0


@pytest.mark.filename("adversarial.pdf")
@pytest.mark.config({"page_range": [0], "texify_inline_spans": True})
def test_llm_text_processor_texify(pdf_document):
    """With ``texify_inline_spans`` enabled, exactly three lines carry the "math" format."""
    # Keep only the lines whose formats include "math".
    math_lines = [
        line
        for line in pdf_document.contained_blocks((BlockTypes.Line,))
        if line.formats and "math" in line.formats
    ]
    assert len(math_lines) == 3
22 changes: 1 addition & 21 deletions tests/processors/test_llm_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,24 +182,4 @@ def test_multi_llm_processors(pdf_document, mocker):

contained_equations = pdf_document.contained_blocks((BlockTypes.Equation,))
print([equation.html for equation in contained_equations])
assert all(equation.html == description for equation in contained_equations)

@pytest.mark.filename("adversarial.pdf")
@pytest.mark.config({"page_range": [0]})
def test_llm_text_processor(pdf_document, mocker):
inline_math_block = pdf_document.contained_blocks((BlockTypes.TextInlineMath,))[0]
text_lines = inline_math_block.contained_blocks(pdf_document, (BlockTypes.Line,))
corrected_lines = ["<i>Text</i>"] * len(text_lines)

mock_cls = Mock()
mock_cls.return_value.generate_response.return_value = {"corrected_lines": corrected_lines}
mocker.patch("marker.processors.llm.GoogleModel", mock_cls)

config = {"use_llm": True, "google_api_key": "test"}
processor_lst = [LLMTextProcessor(config)]
processor = LLMSimpleBlockMetaProcessor(processor_lst, config)
processor(pdf_document)

contained_spans = text_lines[0].contained_blocks(pdf_document, (BlockTypes.Span,))
assert contained_spans[0].text == "Text\n" # Newline inserted at end of line
assert contained_spans[0].formats == ["italic"]
assert all(equation.html == description for equation in contained_equations)

0 comments on commit 42d8ab9

Please sign in to comment.