diff --git a/benchmarks/table/inference.py b/benchmarks/table/inference.py index 0c6432d7..07e5e92c 100644 --- a/benchmarks/table/inference.py +++ b/benchmarks/table/inference.py @@ -27,6 +27,20 @@ def extract_tables(children: List[JSONBlockOutput]): tables.extend(extract_tables(child.children)) return tables +def fix_table_html(table_html: str) -> str: + marker_table_soup = BeautifulSoup(table_html, 'html.parser') + tbody = marker_table_soup.find('tbody') + if tbody: + tbody.unwrap() + for th_tag in marker_table_soup.find_all('th'): + th_tag.name = 'td' + for br_tag in marker_table_soup.find_all('br'): + br_tag.replace_with(marker_table_soup.new_string('')) + + marker_table_html = str(marker_table_soup) + marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines + return marker_table_html + def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool): models = create_model_dict() @@ -154,18 +168,8 @@ def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, m # marker wraps the table in
which fintabnet data doesn't # Fintabnet doesn't use th tags, need to be replaced for fair comparison - marker_table_soup = BeautifulSoup(marker_table.html, 'html.parser') - tbody = marker_table_soup.find('tbody') - if tbody: - tbody.unwrap() - for th_tag in marker_table_soup.find_all('th'): - th_tag.name = 'td' - for br_tag in marker_table_soup.find_all('br'): - br_tag.replace_with(marker_table_soup.new_string('')) - - marker_table_html = str(marker_table_soup) - marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines - gemini_table_html = gemini_table.replace("\n", " ") # Fintabnet uses spaces instead of newlines + marker_table_html = fix_table_html(marker_table.html) + gemini_table_html = fix_table_html(gemini_table) results.append({ "marker_table": marker_table_html, diff --git a/marker/builders/document.py b/marker/builders/document.py index bbc688a6..e87ba001 100644 --- a/marker/builders/document.py +++ b/marker/builders/document.py @@ -2,6 +2,7 @@ from marker.builders import BaseBuilder from marker.builders.layout import LayoutBuilder +from marker.builders.line import LineBuilder from marker.builders.ocr import OcrBuilder from marker.providers.pdf import PdfProvider from marker.schema import BlockTypes @@ -27,9 +28,10 @@ class DocumentBuilder(BaseBuilder): "Disable OCR processing.", ] = False - def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_builder: OcrBuilder): + def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, line_builder: LineBuilder, ocr_builder: OcrBuilder): document = self.build_document(provider) layout_builder(document, provider) + line_builder(document, provider) if not self.disable_ocr: ocr_builder(document, provider) return document diff --git a/marker/builders/layout.py b/marker/builders/layout.py index 0eba225a..d5f0fdbc 100644 --- a/marker/builders/layout.py +++ b/marker/builders/layout.py @@ -1,13 +1,11 @@ -from typing import Annotated, List, Optional, Tuple +from typing import Annotated, List, Optional -import numpy as np from surya.layout import LayoutPredictor from surya.layout.schema import LayoutResult, LayoutBox -from surya.ocr_error import OCRErrorPredictor from surya.ocr_error.schema import OCRErrorDetectionResult from marker.builders import BaseBuilder -from marker.providers import ProviderOutput, ProviderPageLines +from marker.providers import ProviderPageLines from marker.providers.pdf import PdfProvider from marker.schema import BlockTypes from marker.schema.document import Document @@ -15,7 +13,6 @@ from marker.schema.polygon import PolygonBox from marker.schema.registry import get_block_class from marker.settings import settings -from marker.util import matrix_intersection_area class LayoutBuilder(BaseBuilder): @@ -27,33 +24,13 @@ class LayoutBuilder(BaseBuilder): "The batch size to use for the layout model.", "Default is None, which will use the default batch size for the model." ] = None - layout_coverage_min_lines: Annotated[ - int, - "The minimum number of PdfProvider lines that must be covered by the layout model", - "to consider the lines from the PdfProvider valid.", - ] = 1 - layout_coverage_threshold: Annotated[ - float, - "The minimum coverage ratio required for the layout model to consider", - "the lines from the PdfProvider valid.", - ] = .25 - document_ocr_threshold: Annotated[ - float, - "The minimum ratio of pages that must pass the layout coverage check", - "to avoid OCR.", - ] = .8 - excluded_for_coverage: Annotated[ - Tuple[BlockTypes], - "A list of block types to exclude from the layout coverage check.", - ] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup) force_layout_block: Annotated[ str, "Skip layout and force every page to be treated as a specific block type.", ] = None - def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None): + def __init__(self, layout_model: LayoutPredictor, config=None): self.layout_model = layout_model - self.ocr_error_model = ocr_error_model super().__init__(config) @@ -64,7 +41,6 @@ def __call__(self, document: Document, provider: PdfProvider): else: layout_results = self.surya_layout(document.pages) self.add_blocks_to_pages(document.pages, layout_results) - self.merge_blocks(document.pages, provider.page_lines) def get_batch_size(self): if self.layout_batch_size is not None: @@ -100,20 +76,6 @@ def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]: ) return layout_results - def surya_ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: ProviderPageLines) -> OCRErrorDetectionResult: - page_texts = [] - for document_page in pages: - page_text = '' - provider_lines = provider_page_lines.get(document_page.page_id, []) - page_text = '\n'.join(' '.join(s.text for s in line.spans) for line in provider_lines) - page_texts.append(page_text) - - ocr_error_detection_results = self.ocr_error_model( - page_texts, - batch_size=int(self.get_batch_size()) #TODO Better Multiplier - ) - return ocr_error_detection_results - def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[LayoutResult]): for page, layout_result in zip(pages, layout_results): layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size @@ -132,61 +94,4 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou # Ensure page has non-empty children if page.children is None: - page.children = [] - - def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines): - ocr_error_detection_labels = self.surya_ocr_error_detection(document_pages, provider_page_lines).labels - - good_pages = [] - for (document_page, ocr_error_detection_label) in zip(document_pages, ocr_error_detection_labels): - provider_lines = provider_page_lines.get(document_page.page_id, []) - good_pages.append( - bool(provider_lines) and - self.check_layout_coverage(document_page, provider_lines) and - (ocr_error_detection_label != "bad") - ) - - ocr_document = sum(good_pages) / len(good_pages) < self.document_ocr_threshold - for idx, document_page in enumerate(document_pages): - provider_lines = provider_page_lines.get(document_page.page_id, []) - needs_ocr = not good_pages[idx] - if needs_ocr and ocr_document: - document_page.text_extraction_method = "surya" - continue - document_page.merge_blocks(provider_lines, text_extraction_method="pdftext") - document_page.text_extraction_method = "pdftext" - - def check_layout_coverage( - self, - document_page: PageGroup, - provider_lines: List[ProviderOutput], - ): - covered_blocks = 0 - total_blocks = 0 - large_text_blocks = 0 - - layout_blocks = [document_page.get_block(block) for block in document_page.structure] - layout_blocks = [b for b in layout_blocks if b.block_type not in self.excluded_for_coverage] - - layout_bboxes = [block.polygon.bbox for block in layout_blocks] - provider_bboxes = [line.line.polygon.bbox for line in provider_lines] - - intersection_matrix = matrix_intersection_area(layout_bboxes, provider_bboxes) - - for idx, layout_block in enumerate(layout_blocks): - total_blocks += 1 - intersecting_lines = np.count_nonzero(intersection_matrix[idx] > 0) - - if intersecting_lines >= self.layout_coverage_min_lines: - covered_blocks += 1 - - if layout_block.polygon.intersection_pct(document_page.polygon) > 0.8 and layout_block.block_type == BlockTypes.Text: - large_text_blocks += 1 - - coverage_ratio = covered_blocks / total_blocks if total_blocks > 0 else 1 - text_okay = coverage_ratio > self.layout_coverage_threshold - - # Model will sometimes say there is a single block of text on the page when it is blank - if not text_okay and (total_blocks == 1 and large_text_blocks == 1): - text_okay = True - return text_okay + page.children = [] \ No newline at end of file diff --git a/marker/builders/line.py b/marker/builders/line.py new file mode 100644 index 00000000..293da018 --- /dev/null +++ b/marker/builders/line.py @@ -0,0 +1,490 @@ +from copy import deepcopy +from typing import Annotated, List, Optional, Tuple + +import numpy as np +from ftfy import fix_text +from PIL import Image, ImageDraw + +from surya.detection import DetectionPredictor, InlineDetectionPredictor, TextDetectionResult +from surya.ocr_error import OCRErrorPredictor + +from marker.builders import BaseBuilder +from marker.providers import ProviderOutput, ProviderPageLines +from marker.providers.pdf import PdfProvider +from marker.schema import BlockTypes +from marker.schema.document import Document +from marker.schema.groups.page import PageGroup +from marker.schema.polygon import PolygonBox +from marker.schema.registry import get_block_class +from marker.schema.text.line import Line +from marker.settings import settings +from marker.util import matrix_intersection_area + +class TextBox(PolygonBox): + math: bool = False + + def __hash__(self): + return hash(tuple(self.bbox)) + +class LineBuilder(BaseBuilder): + """ + A builder for detecting text lines, and inline math. Merges the detected lines with the lines from the provider + """ + detection_batch_size: Annotated[ + Optional[int], + "The batch size to use for the detection model.", + "Default is None, which will use the default batch size for the model." + ] = None + ocr_error_batch_size: Annotated[ + Optional[int], + "The batch size to use for the ocr error detection model.", + "Default is None, which will use the default batch size for the model." + ] = None + enable_table_ocr: Annotated[ + bool, + "Whether to skip OCR on tables. The TableProcessor will re-OCR them. Only enable if the TableProcessor is not running.", + ] = False + layout_coverage_min_lines: Annotated[ + int, + "The minimum number of PdfProvider lines that must be covered by the layout model", + "to consider the lines from the PdfProvider valid.", + ] = 1 + layout_coverage_threshold: Annotated[ + float, + "The minimum coverage ratio required for the layout model to consider", + "the lines from the PdfProvider valid.", + ] = .25 + span_inline_math_overlap_threshold: Annotated[ + float, + "The minimum overlap of a span with an inline math box to consider for removal" + ] = .5 + char_inline_math_overlap_threshold: Annotated[ + float, + "The minimum overlap of a character with an inline math box to consider for removal" + ] = .5 + line_inline_math_overlap_threshold: Annotated[ + float, + "The minimum overlap of a line with an inline math box to consider as a match" + ] = 0. + line_text_overlap_threshold: Annotated[ + float, + "The minimum overlap of an equation with a text line to consider as a match" + ] = .5 + inline_math_minimum_area: Annotated[ + float, + "The minimum area for an inline math block, in pixels." + ] = 20 + inline_math_line_vertical_merge_threshold: Annotated[ + int, + "The maximum pixel distance between y1s for two lines to be merged" + ] = 8 + excluded_for_coverage: Annotated[ + Tuple[BlockTypes], + "A list of block types to exclude from the layout coverage check.", + ] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup) + use_llm: Annotated[ + bool, + "Whether to use the LLM model for advanced processing." + ] = False + texify_inline_spans: Annotated[ + bool, + "Whether to run texify on inline math spans." + ] = False + ocr_remove_blocks: Tuple[BlockTypes, ...] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents, BlockTypes.Equation) + + def __init__(self, detection_model: DetectionPredictor, inline_detection_model: InlineDetectionPredictor, ocr_error_model: OCRErrorPredictor, config=None): + super().__init__(config) + + self.detection_model = detection_model + self.inline_detection_model = inline_detection_model + self.ocr_error_model = ocr_error_model + + def __call__(self, document: Document, provider: PdfProvider): + # Disable Inline Detection for documents where layout model doesn't detect any equations + # Also disable if we won't use the inline detections (if we aren't using the LLM or texify) + do_inline_math_detection = document.contained_blocks([BlockTypes.Equation]) and (self.texify_inline_spans or self.use_llm) + provider_lines, ocr_lines = self.get_all_lines(document, provider, do_inline_math_detection) + self.merge_blocks(document, provider_lines, ocr_lines) + + def get_detection_batch_size(self): + if self.detection_batch_size is not None: + return self.detection_batch_size + elif settings.TORCH_DEVICE_MODEL == "cuda": + return 4 + return 4 + + def get_ocr_error_batch_size(self): + if self.ocr_error_batch_size is not None: + return self.ocr_error_batch_size + elif settings.TORCH_DEVICE_MODEL == "cuda": + return 4 + return 4 + + def get_detection_results(self, page_images: List[Image.Image], run_detection: List[bool], do_inline_math_detection: bool): + page_images = [p for p, good in zip(page_images, run_detection) if good] + page_detection_results = self.detection_model( + images=page_images, + batch_size=self.get_detection_batch_size() + ) + inline_detection_results = [None] * len(page_detection_results) + if do_inline_math_detection: + inline_detection_results = self.inline_detection_model( + images=page_images, + text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in page_detection_results], + batch_size=self.get_detection_batch_size() + ) + + detection_results = [] + inline_results = [] + idx = 0 + for good in run_detection: + if good: + detection_results.append(page_detection_results[idx]) + inline_results.append(inline_detection_results[idx]) + idx += 1 + else: + detection_results.append(None) + inline_results.append(None) + assert idx == len(page_images) + + return detection_results, inline_results + + + def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_math_detection: bool): + ocr_error_detection_results = self.ocr_error_detection(document.pages, provider.page_lines) + + boxes_to_ocr = {page.page_id: [] for page in document.pages} + page_lines = {page.page_id: [] for page in document.pages} + + LineClass: Line = get_block_class(BlockTypes.Line) + + layout_good = [] + for document_page, ocr_error_detection_label in zip(document.pages, ocr_error_detection_results.labels): + provider_lines: List[ProviderOutput] = provider.page_lines.get(document_page.page_id, []) + provider_lines_good = all([ + bool(provider), + ocr_error_detection_label != 'bad', + self.check_layout_coverage(document_page, provider_lines) + ]) + layout_good.append(provider_lines_good) + + run_detection = [not good or do_inline_math_detection for good in layout_good] + page_images = [page.get_image(highres=False, remove_blocks=self.ocr_remove_blocks) for page, good in zip(document.pages, run_detection) if good] + detection_results, inline_detection_results = self.get_detection_results(page_images, run_detection, do_inline_math_detection) + + assert len(detection_results) == len(inline_detection_results) == len(layout_good) == len(document.pages) + for document_page, detection_result, inline_detection_result, provider_lines_good in zip( + document.pages, + detection_results, + inline_detection_results, + layout_good + ): + provider_lines: List[ProviderOutput] = provider.page_lines.get(document_page.page_id, []) + page_size = provider.get_page_bbox(document_page.page_id).size + image_size = PolygonBox.from_bbox(detection_result.image_bbox).size if detection_result else page_size + + # Merge text and inline math detection results + merged_detection_boxes = self.determine_math_lines(text_result=detection_result, inline_result=inline_detection_result) + math_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if box.math] + nonmath_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if not box.math] + + if provider_lines_good: + # Merge inline math blocks into the provider lines, only persist new detected text lines which do not overlap with existing provider lines + # The missing lines are not from a table, so we can safely set this - The attribute for individual blocks is overridden by OCRBuilder + document_page.text_extraction_method = 'pdftext' + + # Add in the provider lines - merge ones that get broken by inline math + page_lines[document_page.page_id].extend( + self.merge_provider_lines_inline_math( + provider_lines, + [b for _,b in math_detection_boxes], + image_size, + page_size + ) + ) + else: + document_page.text_extraction_method = 'surya' + + # Sort lines properly + full_lines = nonmath_detection_boxes + math_detection_boxes + full_lines = sorted(full_lines, key=lambda x: x[0]) + full_lines = [b for _, b in full_lines] + + # Skip inline math merging if no provider lines are good; OCR all text lines and all inline math lines + boxes_to_ocr[document_page.page_id].extend(full_lines) + + # Dummy lines to merge into the document - Contains no spans, will be filled in later by OCRBuilder + ocr_lines = {document_page.page_id: [] for document_page in document.pages} + for page_id, page_ocr_boxes in boxes_to_ocr.items(): + page_size = provider.get_page_bbox(page_id).size + image_size = document.get_page(page_id).get_image(highres=False).size + for box_to_ocr in page_ocr_boxes: + line_polygon = PolygonBox(polygon=box_to_ocr.polygon).rescale(image_size, page_size) + format = ["math"] if box_to_ocr.math else None + ocr_lines[page_id].append( + ProviderOutput( + line=LineClass( + polygon=line_polygon, + page_id=page_id, + text_extraction_method='surya', + formats=format + ), + spans=[] + ) + ) + + return page_lines, ocr_lines + + def ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: ProviderPageLines): + page_texts = [] + for document_page in pages: + provider_lines = provider_page_lines.get(document_page.page_id, []) + page_text = '\n'.join(' '.join(s.text for s in line.spans) for line in provider_lines) + page_texts.append(page_text) + + ocr_error_detection_results = self.ocr_error_model( + page_texts, + batch_size=int(self.get_ocr_error_batch_size()) + ) + return ocr_error_detection_results + + def check_layout_coverage( + self, + document_page: PageGroup, + provider_lines: List[ProviderOutput], + ): + covered_blocks = 0 + total_blocks = 0 + large_text_blocks = 0 + + layout_blocks = [document_page.get_block(block) for block in document_page.structure] + layout_blocks = [b for b in layout_blocks if b.block_type not in self.excluded_for_coverage] + + layout_bboxes = [block.polygon.bbox for block in layout_blocks] + provider_bboxes = [line.line.polygon.bbox for line in provider_lines] + + if len(layout_bboxes) == 0: + return True + + if len(provider_bboxes) == 0: + return False + + intersection_matrix = matrix_intersection_area(layout_bboxes, provider_bboxes) + + for idx, layout_block in enumerate(layout_blocks): + total_blocks += 1 + intersecting_lines = np.count_nonzero(intersection_matrix[idx] > 0) + + if intersecting_lines >= self.layout_coverage_min_lines: + covered_blocks += 1 + + if layout_block.polygon.intersection_pct(document_page.polygon) > 0.8 and layout_block.block_type == BlockTypes.Text: + large_text_blocks += 1 + + coverage_ratio = covered_blocks / total_blocks if total_blocks > 0 else 1 + text_okay = coverage_ratio >= self.layout_coverage_threshold + + # Model will sometimes say there is a single block of text on the page when it is blank + if not text_okay and (total_blocks == 1 and large_text_blocks == 1): + text_okay = True + return text_okay + + def merge_blocks(self, document: Document, page_provider_lines: ProviderPageLines, page_ocr_lines: ProviderPageLines): + for document_page in document.pages: + provider_lines = page_provider_lines[document_page.page_id] + ocr_lines = page_ocr_lines[document_page.page_id] + + # Only one or the other will have lines + merged_lines = provider_lines + ocr_lines + + # Text extraction method is overridden later for OCRed documents + document_page.merge_blocks(merged_lines, text_extraction_method='pdftext') + + + def determine_math_lines( + self, + text_result: TextDetectionResult, + inline_result: TextDetectionResult, + ) -> List[TextBox]: + """ + Marks lines as math if they contain inline math boxes. + """ + + if not text_result: + return [] + + text_boxes = [ + TextBox( + polygon=box.polygon + ) for box in text_result.bboxes + ] + + # Skip if no inline math was detected + if not inline_result: + return text_boxes + + inline_bboxes = [m.bbox for m in inline_result.bboxes] + text_bboxes = [t.bbox for t in text_boxes] + + if len(inline_bboxes) == 0: + return text_boxes + + if len(text_boxes) == 0: + return [] + + overlaps = matrix_intersection_area(inline_bboxes, text_bboxes) + + # Mark text boxes as math if they overlap with an inline math box + for i, inline_box in enumerate(inline_result.bboxes): + overlap_row = overlaps[i] + max_overlap_idx = np.argmax(overlap_row) + max_overlap_box = text_boxes[max_overlap_idx] + + max_overlap = np.max(overlap_row) / inline_box.area + + # Avoid small or nonoverlapping inline math regions + if max_overlap <= self.line_inline_math_overlap_threshold or inline_box.area < self.inline_math_minimum_area: + continue + + # Ignore vertical lines + if max_overlap_box.height > max_overlap_box.width * 2: + continue + + max_overlap_box.math = True + + return text_boxes + + # Add appropriate formats to math spans added by inline math detection + def add_math_span_format(self, provider_line): + if not provider_line.line.formats: + provider_line.line.formats = ["math"] + elif "math" not in provider_line.line.formats: + provider_line.line.formats.append("math") + + def merge_provider_lines_inline_math( + self, + provider_lines: List[ProviderOutput], + inline_math_lines: List[TextBox], + image_size, + page_size + ): + # When provider lines is empty or no inline math detected, return provider lines + if not provider_lines or not inline_math_lines: + return provider_lines + + horizontal_provider_lines = [ + (j, provider_line) for j, provider_line in enumerate(provider_lines) + if provider_line.line.polygon.height < provider_line.line.polygon.width * 3 # Multiply to account for small blocks inside equations, but filter out big vertical lines + ] + provider_line_boxes = [p.line.polygon.bbox for _, p in horizontal_provider_lines] + math_line_boxes = [PolygonBox(polygon=m.polygon).rescale(image_size, page_size).bbox for m in inline_math_lines] + + overlaps = matrix_intersection_area(math_line_boxes, provider_line_boxes) + + # Find potential merges + merge_lines = [] + for i in range(len(math_line_boxes)): + merge_line = [] + math_line_polygon = PolygonBox(polygon=inline_math_lines[i].polygon).rescale(image_size, page_size) + max_overlap = np.max(overlaps[i]) + if max_overlap <= self.line_inline_math_overlap_threshold: + continue + + best_overlap = np.argmax(overlaps[i]) + best_overlap_line = horizontal_provider_lines[best_overlap] + best_overlap_y1 = best_overlap_line[1].line.polygon.y_start + + nonzero_idxs = np.nonzero(overlaps[i] > self.line_inline_math_overlap_threshold)[0] + for idx in nonzero_idxs: + provider_idx, provider_line = horizontal_provider_lines[idx] + provider_line_y1 = provider_line.line.polygon.y_start + + should_merge_line = False + if abs(provider_line_y1 - best_overlap_y1) <= self.inline_math_line_vertical_merge_threshold: + should_merge_line = True + + line_overlaps = self.find_overlapping_math_chars(provider_line, math_line_polygon, remove_chars=not should_merge_line) + + # Do not merge if too far above/below (but remove characters) + if line_overlaps and should_merge_line: + # Add the index of the provider line to the merge line + merge_line.append(provider_idx) + + if len(merge_line) > 0: + merge_lines.append(merge_line) + + # Handle the merging + already_merged = set() + potential_merges = set([m for merge_line in merge_lines for m in merge_line]) + out_provider_lines = [(i, p) for i, p in enumerate(provider_lines) if i not in potential_merges] + for merge_section in merge_lines: + merge_section = [m for m in merge_section if m not in already_merged] + if len(merge_section) == 0: + continue + elif len(merge_section) == 1: + line_idx = merge_section[0] + merged_line = provider_lines[line_idx] + self.add_math_span_format(merged_line) + out_provider_lines.append((line_idx, merged_line)) + already_merged.add(merge_section[0]) + continue + + merge_section = sorted(merge_section) + merged_line = None + min_idx = min(merge_section) + for idx in merge_section: + provider_line = deepcopy(provider_lines[idx]) + if merged_line is None: + merged_line = provider_line + else: + # Combine the spans of the provider line with the merged line + merged_line = merged_line.merge(provider_line) + self.add_math_span_format(merged_line) + already_merged.add(idx) # Prevent double merging + out_provider_lines.append((min_idx, merged_line)) + + # Sort to preserve original order + out_provider_lines = sorted(out_provider_lines, key=lambda x: x[0]) + out_provider_lines = [p for _, p in out_provider_lines] + return out_provider_lines + + def clear_line_text(self, provider_line): + for span in provider_line.spans: + span.text = "" + + def find_overlapping_math_chars(self, provider_line, math_line_polygon, remove_chars=False): + # Identify if a character in the provider line overlaps with the inline math line - meaning that the line can be treated as math + spans = provider_line.spans + math_overlaps = False + + # For providers which do not surface characters + if provider_line.chars is None: + for span in spans: + if span.polygon.intersection_pct(math_line_polygon) > self.span_inline_math_overlap_threshold: + math_overlaps = True + return math_overlaps + + # For providers which surface characters - find line overlap based on characters + assert len(spans) == len(provider_line.chars), "Number of spans and characters in provider line do not match" + for span, span_chars in zip(spans, provider_line.chars): + if len(span_chars) == 0: + continue + + char_intersections_areas = matrix_intersection_area([char.polygon.bbox for char in span_chars], [math_line_polygon.bbox]).max(axis=-1) + char_intersections = char_intersections_areas / np.array([char.polygon.area for char in span_chars]) + + new_span_chars = [] + span_overlaps = False + for char, intersection_pct in zip(span_chars, char_intersections): + if intersection_pct >= self.char_inline_math_overlap_threshold: + span_overlaps = True + else: + new_span_chars.append(char) + + # Remove stray characters that overlap with math lines + if span_overlaps and remove_chars: + span.text = fix_text(''.join(c.char for c in new_span_chars)) + + math_overlaps = math_overlaps or span_overlaps + + return math_overlaps \ No newline at end of file diff --git a/marker/builders/llm_layout.py b/marker/builders/llm_layout.py index 4a0ca134..4ba76994 100644 --- a/marker/builders/llm_layout.py +++ b/marker/builders/llm_layout.py @@ -2,7 +2,6 @@ from typing import Annotated from surya.layout import LayoutPredictor -from surya.ocr_error import OCRErrorPredictor from tqdm import tqdm from pydantic import BaseModel @@ -98,8 +97,8 @@ class LLMLayoutBuilder(LayoutBuilder): Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`. """ - def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None): - super().__init__(layout_model, ocr_error_model, config) + def __init__(self, layout_model: LayoutPredictor, config=None): + super().__init__(layout_model, config) self.model = GoogleModel(self.google_api_key, self.model_name) diff --git a/marker/builders/ocr.py b/marker/builders/ocr.py index 43308c5b..18451be8 100644 --- a/marker/builders/ocr.py +++ b/marker/builders/ocr.py @@ -1,21 +1,20 @@ +import copy from typing import Annotated, List, Optional from ftfy import fix_text -from surya.detection import DetectionPredictor from surya.recognition import RecognitionPredictor from marker.builders import BaseBuilder -from marker.providers import ProviderOutput, ProviderPageLines +from marker.providers import ProviderPageLines from marker.providers.pdf import PdfProvider from marker.schema import BlockTypes +from marker.schema.blocks import BlockId from marker.schema.document import Document -from marker.schema.polygon import PolygonBox +from marker.schema.groups import PageGroup from marker.schema.registry import get_block_class -from marker.schema.text.line import Line from marker.schema.text.span import Span from marker.settings import settings - class OcrBuilder(BaseBuilder): """ A builder for performing OCR on PDF pages and merging the results into the document. @@ -25,30 +24,21 @@ class OcrBuilder(BaseBuilder): "The batch size to use for the recognition model.", "Default is None, which will use the default batch size for the model." ] = None - detection_batch_size: Annotated[ - Optional[int], - "The batch size to use for the detection model.", - "Default is None, which will use the default batch size for the model." - ] = None languages: Annotated[ Optional[List[str]], "A list of languages to use for OCR.", "Default is None." ] = None - enable_table_ocr: Annotated[ - bool, - "Whether to skip OCR on tables. The TableProcessor will re-OCR them. Only enable if the TableProcessor is not running.", - ] = False - def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None): + def __init__(self, recognition_model: RecognitionPredictor, config=None): super().__init__(config) - self.detection_model = detection_model self.recognition_model = recognition_model def __call__(self, document: Document, provider: PdfProvider): - page_lines = self.ocr_extraction(document, provider) - self.merge_blocks(document, page_lines) + pages_to_ocr = [page for page in document.pages if page.text_extraction_method == 'surya'] + images, line_boxes, line_ids = self.get_ocr_images_boxes_ids(document, pages_to_ocr, provider) + self.ocr_extraction(document, pages_to_ocr, provider, images, line_boxes, line_ids) def get_recognition_batch_size(self): if self.recognition_batch_size is not None: @@ -59,64 +49,61 @@ def get_recognition_batch_size(self): return 32 return 32 - def get_detection_batch_size(self): - if self.detection_batch_size is not None: - return self.detection_batch_size - elif settings.TORCH_DEVICE_MODEL == "cuda": - return 4 - return 4 - - def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderPageLines: - page_list = [page for page in document.pages if page.text_extraction_method == "surya"] + def get_ocr_images_boxes_ids(self, document: Document, pages: List[PageGroup], provider: PdfProvider): + highres_images, highres_boxes, line_ids = [], [], [] + for document_page in pages: + page_highres_image = document_page.get_image(highres=True) + page_highres_boxes = [] + page_line_ids = [] + + page_size = provider.get_page_bbox(document_page.page_id).size + image_size = page_highres_image.size + for block in document_page.contained_blocks(document): + block_lines = block.contained_blocks(document, [BlockTypes.Line]) + block_detected_lines = [block_line for block_line in block_lines if block_line.text_extraction_method == 'surya'] + + block.text_extraction_method = 'surya' + for line in block_detected_lines: + line_polygon = copy.deepcopy(line.polygon) + page_highres_boxes.append(line_polygon.rescale(page_size, image_size).bbox) + page_line_ids.append(line.id) + + highres_images.append(page_highres_image) + highres_boxes.append(page_highres_boxes) + line_ids.append(page_line_ids) + + return highres_images, highres_boxes, line_ids + + def ocr_extraction(self, document: Document, pages: List[PageGroup], provider: PdfProvider, images: List[any], line_boxes: List[List[float]], line_ids: List[List[BlockId]]): + if sum(len(b) for b in line_boxes)==0: + return - # Remove tables because we re-OCR them later with the table processor recognition_results = self.recognition_model( - images=[page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page in page_list], - langs=[self.languages] * len(page_list), - det_predictor=self.detection_model, - detection_batch_size=int(self.get_detection_batch_size()), + images=images, + bboxes=line_boxes, + langs=[self.languages] * len(document.pages), recognition_batch_size=int(self.get_recognition_batch_size()), - highres_images=[page.get_image(highres=True, remove_tables=not self.enable_table_ocr) for page in page_list] + sort_lines=False ) - page_lines = {} - SpanClass: Span = get_block_class(BlockTypes.Span) - LineClass: Line = get_block_class(BlockTypes.Line) - - for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results): - page_lines.setdefault(page_id, []) - - page_size = provider.get_page_bbox(page_id).size - - for ocr_line_idx, ocr_line in enumerate(recognition_result.text_lines): - image_polygon = PolygonBox.from_bbox(recognition_result.image_bbox) - polygon = PolygonBox.from_bbox(ocr_line.bbox).rescale(image_polygon.size, page_size) - - line = LineClass( - polygon=polygon, - page_id=page_id, + for document_page, page_recognition_result, page_line_ids in zip(pages, recognition_results, line_ids): + for line_id, ocr_line in zip(page_line_ids, page_recognition_result.text_lines): + if not fix_text(ocr_line.text): + continue + + line = document_page.get_block(line_id) + assert line.structure is None + new_span = SpanClass( + text=fix_text(ocr_line.text) + '\n', + formats=['plain'], + page_id=document_page.page_id, + polygon=copy.deepcopy(line.polygon), + minimum_position=0, + maximum_position=0, + font='Unknown', + font_weight=0, + font_size=0, ) - spans = [ - SpanClass( - text=fix_text(ocr_line.text) + "\n", - formats=['plain'], - page_id=page_id, - polygon=polygon, - minimum_position=0, - maximum_position=0, - font='Unknown', - font_weight=0, - font_size=0, - ) - ] - - page_lines[page_id].append(ProviderOutput(line=line, spans=spans)) - - return page_lines - - def merge_blocks(self, document: Document, page_lines: ProviderPageLines): - ocred_pages = [page for page in document.pages if page.text_extraction_method == "surya"] - for document_page in ocred_pages: - lines = page_lines[document_page.page_id] - document_page.merge_blocks(lines, text_extraction_method="surya") + document_page.add_full_block(new_span) + line.add_structure(new_span) \ No newline at end of file diff --git a/marker/converters/pdf.py b/marker/converters/pdf.py index 99080062..849d8a71 100644 --- a/marker/converters/pdf.py +++ b/marker/converters/pdf.py @@ -12,6 +12,7 @@ from marker.builders.document import DocumentBuilder from marker.builders.layout import LayoutBuilder from marker.builders.llm_layout import LLMLayoutBuilder +from marker.builders.line import LineBuilder from marker.builders.ocr import OcrBuilder from marker.builders.structure import StructureBuilder from marker.converters import BaseConverter @@ -115,10 +116,12 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: Optional[List[ def build_document(self, filepath: str): provider_cls = provider_from_filepath(filepath) layout_builder = self.resolve_dependencies(self.layout_builder_class) + line_builder = self.resolve_dependencies(LineBuilder) ocr_builder = self.resolve_dependencies(OcrBuilder) with provider_cls(filepath, self.config) as provider: - document = DocumentBuilder(self.config)(provider, layout_builder, ocr_builder) - StructureBuilder(self.config)(document) + document = DocumentBuilder(self.config)(provider, layout_builder, line_builder, ocr_builder) + structure_builder_cls = self.resolve_dependencies(StructureBuilder) + structure_builder_cls(document) for processor in self.processor_list: processor(document) diff --git a/marker/converters/table.py b/marker/converters/table.py index 40ef1b0e..d73e1cc6 100644 --- a/marker/converters/table.py +++ b/marker/converters/table.py @@ -2,6 +2,7 @@ from typing import Tuple, List from marker.builders.document import DocumentBuilder +from marker.builders.line import LineBuilder from marker.builders.ocr import OcrBuilder from marker.converters.pdf import PdfConverter from marker.processors import BaseProcessor @@ -28,11 +29,12 @@ class TableConverter(PdfConverter): def build_document(self, filepath: str): provider_cls = provider_from_filepath(filepath) layout_builder = self.resolve_dependencies(self.layout_builder_class) + line_builder = self.resolve_dependencies(LineBuilder) ocr_builder = self.resolve_dependencies(OcrBuilder) document_builder = DocumentBuilder(self.config) document_builder.disable_ocr = True with provider_cls(filepath, self.config) as provider: - document = document_builder(provider, layout_builder, ocr_builder) + document = document_builder(provider, layout_builder, line_builder, ocr_builder) for page in document.pages: page.structure = [p for p in page.structure if p.block_type in self.converter_block_types] diff --git a/marker/models.py b/marker/models.py index 80dc254c..0a6b62eb 100644 --- a/marker/models.py +++ b/marker/models.py @@ -1,7 +1,7 @@ import os os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS -from surya.detection import DetectionPredictor +from surya.detection import DetectionPredictor, InlineDetectionPredictor from surya.layout import LayoutPredictor from surya.ocr_error import OCRErrorPredictor from surya.recognition import RecognitionPredictor @@ -16,5 +16,6 @@ def create_model_dict(device=None, dtype=None) -> dict: "recognition_model": RecognitionPredictor(device=device, dtype=dtype), "table_rec_model": TableRecPredictor(device=device, dtype=dtype), "detection_model": DetectionPredictor(device=device, dtype=dtype), + "inline_detection_model": InlineDetectionPredictor(device=device, dtype=dtype), "ocr_error_model": OCRErrorPredictor(device=device, dtype=dtype) } \ No newline at end of file diff --git a/marker/processors/equation.py b/marker/processors/equation.py index 20ac0fb4..99cd1fef 100644 --- a/marker/processors/equation.py +++ b/marker/processors/equation.py @@ -1,9 +1,10 @@ from typing import Annotated, List, Optional, Tuple -from tqdm import tqdm from marker.models import TexifyPredictor from marker.processors import BaseProcessor +from marker.processors.util import add_math_spans_to_line from marker.schema import BlockTypes +from marker.schema.blocks import Equation from marker.schema.document import Document from marker.settings import settings @@ -33,6 +34,10 @@ class EquationProcessor(BaseProcessor): bool, "Whether to disable the tqdm progress bar.", ] = False + texify_inline_spans: Annotated[ + bool, + "Whether to run texify on inline math spans." + ] = False def __init__(self, texify_model: TexifyPredictor, config=None): super().__init__(config) @@ -43,7 +48,13 @@ def __call__(self, document: Document): equation_data = [] for page in document.pages: - for block in page.contained_blocks(document, self.block_types): + equation_blocks = page.contained_blocks(document, self.block_types) + math_blocks = [] + if self.texify_inline_spans: + math_blocks = page.contained_blocks(document, (BlockTypes.Line,)) + math_blocks = [m for m in math_blocks if m.formats and "math" in m.formats] + + for block in equation_blocks + math_blocks: image = block.get_image(document, highres=False).convert("RGB") raw_text = block.raw_text(document) token_count = self.get_total_texify_tokens(raw_text) @@ -51,7 +62,8 @@ def __call__(self, document: Document): equation_data.append({ "image": image, "block_id": block.id, - "token_count": token_count + "token_count": token_count, + "page": page }) if len(equation_data) == 0: @@ -67,7 +79,11 @@ def __call__(self, document: Document): continue block = document.get_block(equation_d["block_id"]) - block.html = prediction + if isinstance(block, Equation): + block.html = prediction + else: + block.structure = [] + add_math_spans_to_line(prediction, block, equation_d["page"]) def get_batch_size(self): if self.texify_batch_size is not None: diff --git a/marker/processors/llm/__init__.py b/marker/processors/llm/__init__.py index 3f3f64be..975c363d 100644 --- a/marker/processors/llm/__init__.py +++ b/marker/processors/llm/__init__.py @@ -1,12 +1,13 @@ import traceback from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Annotated, TypedDict, List +from typing import Annotated, TypedDict, List, Sequence from pydantic import BaseModel from tqdm import tqdm from PIL import Image from marker.processors import BaseProcessor +from marker.schema import BlockTypes from marker.services.google import GoogleModel from marker.schema.blocks import Block from marker.schema.document import Document @@ -20,6 +21,7 @@ class PromptData(TypedDict): block: Block schema: BaseModel page: PageGroup + additional_data: dict | None class BlockData(TypedDict): @@ -74,8 +76,13 @@ def __init__(self, config=None): self.model = GoogleModel(self.google_api_key, self.model_name) - def extract_image(self, document: Document, image_block: Block): - return image_block.get_image(document, highres=True, expansion=(self.image_expansion_ratio, self.image_expansion_ratio)) + def extract_image(self, document: Document, image_block: Block, remove_blocks: Sequence[BlockTypes] | None = None) -> Image.Image: + return image_block.get_image( + document, + highres=True, + expansion=(self.image_expansion_ratio, self.image_expansion_ratio), + remove_blocks=remove_blocks + ) class BaseLLMComplexBlockProcessor(BaseLLMProcessor): diff --git a/marker/processors/llm/llm_handwriting.py b/marker/processors/llm/llm_handwriting.py index 19721da4..d169a29c 100644 --- a/marker/processors/llm/llm_handwriting.py +++ b/marker/processors/llm/llm_handwriting.py @@ -39,6 +39,7 @@ def inference_blocks(self, document: Document) -> List[BlockData]: for block_data in blocks: raw_text = block_data["block"].raw_text(document) block = block_data["block"] + # Don't process text blocks that contain lines already if block.block_type == BlockTypes.Text: lines = block.contained_blocks(document, (BlockTypes.Line,)) diff --git a/marker/processors/llm/llm_text.py b/marker/processors/llm/llm_text.py index 9a17455e..e49dff96 100644 --- a/marker/processors/llm/llm_text.py +++ b/marker/processors/llm/llm_text.py @@ -1,18 +1,26 @@ import json -from typing import List, Tuple +from typing import List, Tuple, Annotated from pydantic import BaseModel +from PIL import Image -from marker.processors.llm import BaseLLMSimpleBlockProcessor, PromptData -from bs4 import BeautifulSoup +from marker.processors.llm import BaseLLMSimpleBlockProcessor, PromptData, BlockData + +from marker.processors.util import add_math_spans_to_line from marker.schema import BlockTypes from marker.schema.blocks import Block from marker.schema.document import Document -from marker.schema.registry import get_block_class +from marker.schema.text import Line class LLMTextProcessor(BaseLLMSimpleBlockProcessor): - block_types = (BlockTypes.TextInlineMath,) + math_line_batch_size: Annotated[ + int, + "The number of math lines to batch together.", + ] = 10 + + block_types = (BlockTypes.Line,) + image_remove_blocks = (BlockTypes.Equation,) text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images. You will receive an image of a text block and a set of extracted lines corresponding to the text in the image. Your task is to correct any errors in the extracted lines, including math, formatting, and other inaccuracies, and output the corrected lines in a JSON format. @@ -28,8 +36,8 @@ class LLMTextProcessor(BaseLLMSimpleBlockProcessor): * Formatting: Maintain consistent formatting with the text block image, including spacing, indentation, and special characters. * Other inaccuracies: If the image is handwritten then you may correct any spelling errors, or other discrepancies. 5. Do not remove any formatting i.e bold, italics, math, superscripts, subscripts, etc from the extracted lines unless it is necessary to correct an error. -6. Ensure that inline math is properly with inline math tags. -7. The number of corrected lines in the output MUST equal the number of extracted lines provided in the input. Do not add or remove lines. +6. Ensure that inline math is properly formatted with inline math tags. +7. The number of corrected lines in the output MUST equal the number of extracted lines provided in the input. Do not add or remove lines. There are exactly {line_count} input lines. 8. Output the corrected lines in JSON format with a "lines" field, as shown in the example below. 9. You absolutely cannot remove any ... tags, those are extremely important for references and are coming directly from the document, you MUST always preserve them. @@ -70,101 +78,83 @@ class LLMTextProcessor(BaseLLMSimpleBlockProcessor): ``` """ + def inference_blocks(self, document: Document) -> List[List[BlockData]]: + blocks = [] + for page in document.pages: + for block in page.contained_blocks(document, self.block_types): + if block.formats and "math" in block.formats: + blocks.append({ + "page": page, + "block": block + }) + + out_blocks = [] + for i in range(0, len(blocks), self.math_line_batch_size): + batch = blocks[i:i + self.math_line_batch_size] + out_blocks.append(batch) + return out_blocks + def get_block_lines(self, block: Block, document: Document) -> Tuple[list, list]: text_lines = block.contained_blocks(document, (BlockTypes.Line,)) extracted_lines = [line.formatted_text(document) for line in text_lines] return text_lines, extracted_lines + def combine_images(self, images: List[Image.Image]): + widths, heights = zip(*(i.size for i in images)) + total_width = max(widths) + total_height = sum(heights) + 5 * len(images) + + new_im = Image.new('RGB', (total_width, total_height), (255, 255, 255)) + + y_offset = 0 + for im in images: + new_im.paste(im, (0, y_offset)) + y_offset += im.size[1] + 5 + + return new_im + def block_prompts(self, document: Document) -> List[PromptData]: prompt_data = [] for block_data in self.inference_blocks(document): - block = block_data["block"] - _, extracted_lines = self.get_block_lines(block, document) + blocks: List[Line] = [b["block"] for b in block_data] + pages = [b["page"] for b in block_data] + block_lines = [block.formatted_text(document) for block in blocks] + + prompt = ( + self.text_math_rewriting_prompt + .replace("{extracted_lines}",json.dumps({"extracted_lines": block_lines}, indent=2)) + .replace("{line_count}", str(len(block_lines))) + ) + images = [self.extract_image(document, block, remove_blocks=self.image_remove_blocks) for block in blocks] + image = self.combine_images(images) - prompt = self.text_math_rewriting_prompt.replace("{extracted_lines}", - json.dumps({"extracted_lines": extracted_lines}, indent=2)) - image = self.extract_image(document, block) prompt_data.append({ "prompt": prompt, "image": image, - "block": block, + "block": blocks[0], "schema": LLMTextSchema, - "page": block_data["page"] + "page": pages[0], + "additional_data": {"blocks": blocks, "pages": pages} }) return prompt_data def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document): - block = prompt_data["block"] - page = prompt_data["page"] - SpanClass = get_block_class(BlockTypes.Span) + blocks = prompt_data["additional_data"]["blocks"] + pages = prompt_data["additional_data"]["pages"] - text_lines, extracted_lines = self.get_block_lines(block, document) if not response or "corrected_lines" not in response: - block.update_metadata(llm_error_count=1) + blocks[0].update_metadata(llm_error_count=1) return corrected_lines = response["corrected_lines"] - if not corrected_lines or len(corrected_lines) != len(extracted_lines): - block.update_metadata(llm_error_count=1) + if not corrected_lines or len(corrected_lines) != len(blocks): + blocks[0].update_metadata(llm_error_count=1) return - for text_line, corrected_text in zip(text_lines, corrected_lines): + for text_line, page, corrected_text in zip(blocks, pages, corrected_lines): text_line.structure = [] - corrected_spans = self.text_to_spans(corrected_text) - - for span_idx, span in enumerate(corrected_spans): - if span_idx == len(corrected_spans) - 1: - span['content'] += "\n" - - span_block = page.add_full_block( - SpanClass( - polygon=text_line.polygon, - text=span['content'], - font='Unknown', - font_weight=0, - font_size=0, - minimum_position=0, - maximum_position=0, - formats=[span['type']], - url=span.get('url'), - page_id=text_line.page_id, - text_extraction_method="gemini", - ) - ) - text_line.structure.append(span_block.id) - - @staticmethod - def text_to_spans(text): - soup = BeautifulSoup(text, 'html.parser') - - tag_types = { - 'b': 'bold', - 'i': 'italic', - 'math': 'math', - } - spans = [] - - for element in soup.descendants: - if not len(list(element.parents)) == 1: - continue - - url = element.attrs.get('href') if hasattr(element, 'attrs') else None - - if element.name in tag_types: - spans.append({ - 'type': tag_types[element.name], - 'content': element.get_text(), - 'url': url - }) - elif element.string: - spans.append({ - 'type': 'plain', - 'content': element.string, - 'url': url - }) - - return spans + add_math_spans_to_line(corrected_text, text_line, page) class LLMTextSchema(BaseModel): corrected_lines: List[str] \ No newline at end of file diff --git a/marker/processors/util.py b/marker/processors/util.py new file mode 100644 index 00000000..859a51e6 --- /dev/null +++ b/marker/processors/util.py @@ -0,0 +1,64 @@ +from bs4 import BeautifulSoup + +from marker.schema import BlockTypes +from marker.schema.groups import PageGroup +from marker.schema.registry import get_block_class +from marker.schema.text import Line + + +def add_math_spans_to_line(corrected_text: str, text_line: Line, page: PageGroup): + SpanClass = get_block_class(BlockTypes.Span) + corrected_spans = text_to_spans(corrected_text) + + for span_idx, span in enumerate(corrected_spans): + if span_idx == len(corrected_spans) - 1: + span['content'] += "\n" + + span_block = page.add_full_block( + SpanClass( + polygon=text_line.polygon, + text=span['content'], + font='Unknown', + font_weight=0, + font_size=0, + minimum_position=0, + maximum_position=0, + formats=[span['type']], + url=span.get('url'), + page_id=text_line.page_id, + text_extraction_method="gemini", + ) + ) + text_line.structure.append(span_block.id) + + +def text_to_spans(text): + soup = BeautifulSoup(text, 'html.parser') + + tag_types = { + 'b': 'bold', + 'i': 'italic', + 'math': 'math', + } + spans = [] + + for element in soup.descendants: + if not len(list(element.parents)) == 1: + continue + + url = element.attrs.get('href') if hasattr(element, 'attrs') else None + + if element.name in tag_types: + spans.append({ + 'type': tag_types[element.name], + 'content': element.get_text(), + 'url': url + }) + elif element.string: + spans.append({ + 'type': 'plain', + 'content': element.string, + 'url': url + }) + + return spans \ No newline at end of file diff --git a/marker/providers/__init__.py b/marker/providers/__init__.py index 5230a410..448c70a5 100644 --- a/marker/providers/__init__.py +++ b/marker/providers/__init__.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import List, Optional, Dict from PIL import Image @@ -10,15 +11,37 @@ from marker.schema.text.line import Line from marker.util import assign_config +class Char(BaseModel): + char: str + polygon: PolygonBox + char_idx: int class ProviderOutput(BaseModel): line: Line spans: List[Span] + chars: Optional[List[List[Char]]] = None @property def raw_text(self): return "".join(span.text for span in self.spans) + def __hash__(self): + return hash(tuple(self.line.polygon.bbox)) + + def merge(self, other: "ProviderOutput"): + new_output = deepcopy(self) + other_copy = deepcopy(other) + + new_output.spans.extend(other_copy.spans) + if new_output.chars is not None and other_copy.chars is not None: + new_output.chars.extend(other_copy.chars) + elif other_copy.chars is not None: + new_output.chars = other_copy.chars + + new_output.line.polygon = new_output.line.polygon.merge([other_copy.line.polygon]) + return new_output + + ProviderPageLines = Dict[int, List[ProviderOutput]] class BaseProvider: diff --git a/marker/providers/pdf.py b/marker/providers/pdf.py index 2b805f00..19f39686 100644 --- a/marker/providers/pdf.py +++ b/marker/providers/pdf.py @@ -11,7 +11,7 @@ from PIL import Image from pypdfium2 import PdfiumError -from marker.providers import BaseProvider, ProviderOutput, ProviderPageLines +from marker.providers import BaseProvider, ProviderOutput, Char, ProviderPageLines from marker.providers.utils import alphanum_ratio from marker.schema import BlockTypes from marker.schema.polygon import PolygonBox @@ -171,7 +171,7 @@ def pdftext_extraction(self) -> ProviderPageLines: page_char_blocks = dictionary_output( self.filepath, page_range=self.page_range, - keep_chars=False, + keep_chars=True, workers=self.pdftext_workers, flatten_pdf=self.flatten_pdf, quote_loosebox=False, @@ -191,6 +191,7 @@ def pdftext_extraction(self) -> ProviderPageLines: for block in page["blocks"]: for line in block["lines"]: spans: List[Span] = [] + chars: List[List[Char]] = [] for span in line["spans"]: if not span["text"]: continue @@ -199,6 +200,7 @@ def pdftext_extraction(self) -> ProviderPageLines: font_weight = span["font"]["weight"] or 0 font_size = span["font"]["size"] or 0 polygon = PolygonBox.from_bbox(span["bbox"], ensure_nonzero_area=True) + span_chars = [Char(char=c['char'], polygon=PolygonBox.from_bbox(c['bbox'], ensure_nonzero_area=True), char_idx=c['char_idx']) for c in span["chars"]] spans.append( SpanClass( polygon=polygon, @@ -214,11 +216,14 @@ def pdftext_extraction(self) -> ProviderPageLines: url=span.get("url"), ) ) + chars.append(span_chars) polygon = PolygonBox.from_bbox(line["bbox"], ensure_nonzero_area=True) + assert len(spans) == len(chars) lines.append( ProviderOutput( line=LineClass(polygon=polygon, page_id=page_id), - spans=spans + spans=spans, + chars=chars ) ) if self.check_line_spans(lines): diff --git a/marker/schema/blocks/base.py b/marker/schema/blocks/base.py index 69952f70..5ff40d4b 100644 --- a/marker/schema/blocks/base.py +++ b/marker/schema/blocks/base.py @@ -101,11 +101,11 @@ def from_block(cls, block: Block) -> Block: block_attrs = block.model_dump(exclude=["id", "block_id", "block_type"]) return cls(**block_attrs) - def get_image(self, document: Document, highres: bool = False, expansion: Tuple[float, float] | None = None) -> Image.Image | None: + def get_image(self, document: Document, highres: bool = False, expansion: Tuple[float, float] | None = None, remove_blocks: Sequence[BlockTypes] | None = None) -> Image.Image | None: image = self.highres_image if highres else self.lowres_image if image is None: page = document.get_page(self.page_id) - page_image = page.highres_image if highres else page.lowres_image + page_image = page.get_image(highres=highres, remove_blocks=remove_blocks) # Scale to the image size bbox = self.polygon.rescale((page.polygon.width, page.polygon.height), page_image.size) diff --git a/marker/schema/groups/page.py b/marker/schema/groups/page.py index af0790c2..681a07cd 100644 --- a/marker/schema/groups/page.py +++ b/marker/schema/groups/page.py @@ -38,16 +38,16 @@ def add_child(self, block: Block): else: self.children.append(block) - def get_image(self, *args, highres: bool = False, remove_tables: bool = False, **kwargs): + def get_image(self, *args, highres: bool = False, remove_blocks: Sequence[BlockTypes] | None = None, **kwargs): image = self.highres_image if highres else self.lowres_image - # Avoid double OCR for tables - if remove_tables: + # Avoid double OCR for certain elements + if remove_blocks: image = image.copy() draw = ImageDraw.Draw(image) - table_blocks = [block for block in self.children if block.block_type in (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents)] - for table_block in table_blocks: - poly = table_block.polygon.rescale(self.polygon.size, image.size).polygon + bad_blocks = [block for block in self.children if block.block_type in remove_blocks] + for bad_block in bad_blocks: + poly = bad_block.polygon.rescale(self.polygon.size, image.size).polygon poly = [(int(p[0]), int(p[1])) for p in poly] draw.polygon(poly, fill='white') diff --git a/marker/schema/polygon.py b/marker/schema/polygon.py index 25e9ed31..42bce5dd 100644 --- a/marker/schema/polygon.py +++ b/marker/schema/polygon.py @@ -90,6 +90,30 @@ def expand(self, x_margin: float, y_margin: float) -> PolygonBox: new_polygon.append([poly[0] - x_margin, poly[1] + y_margin]) return PolygonBox(polygon=new_polygon) + def expand_y2(self, y_margin: float) -> PolygonBox: + new_polygon = [] + y_margin = y_margin * self.height + for idx, poly in enumerate(self.polygon): + if idx == 2: + new_polygon.append([poly[0], poly[1] + y_margin]) + elif idx == 3: + new_polygon.append([poly[0], poly[1] + y_margin]) + else: + new_polygon.append(poly) + return PolygonBox(polygon=new_polygon) + + def expand_y1(self, y_margin: float) -> PolygonBox: + new_polygon = [] + y_margin = y_margin * self.height + for idx, poly in enumerate(self.polygon): + if idx == 0: + new_polygon.append([poly[0], poly[1] - y_margin]) + elif idx == 1: + new_polygon.append([poly[0], poly[1] - y_margin]) + else: + new_polygon.append(poly) + return PolygonBox(polygon=new_polygon) + def minimum_gap(self, other: PolygonBox): if self.intersection_pct(other) > 0: return 0 diff --git a/marker/schema/text/line.py b/marker/schema/text/line.py index 6285ee88..9e8a0141 100644 --- a/marker/schema/text/line.py +++ b/marker/schema/text/line.py @@ -1,5 +1,6 @@ import html import re +from typing import Literal, List import regex @@ -36,6 +37,7 @@ def strip_trailing_hyphens(line_text, next_line_text, line_html) -> str: class Line(Block): block_type: BlockTypes = BlockTypes.Line block_description: str = "A line of text." + formats: List[Literal["math"]] | None = None # Sometimes we want to set math format at the line level, not span def formatted_text(self, document): text = "" diff --git a/poetry.lock b/poetry.lock index 0234f789..45ac5263 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1092,18 +1092,19 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"] [[package]] name = "google-genai" -version = "1.1.0" +version = "1.2.0" description = "GenAI Python SDK" optional = false python-versions = ">=3.9" files = [ - {file = "google_genai-1.1.0-py3-none-any.whl", hash = "sha256:c48ac44612ad6aadc0bf96b12fa4314756baa16382c890fff793bcb53e9a9cc8"}, + {file = "google_genai-1.2.0-py3-none-any.whl", hash = "sha256:609d61bee73f1a6ae5b47e9c7dd4b469d50318f050c5ceacf835b0f80f79d2d9"}, ] [package.dependencies] google-auth = ">=2.14.1,<3.0.0dev" pydantic = ">=2.0.0,<3.0.0dev" requests = ">=2.28.1,<3.0.0dev" +typing-extensions = ">=4.11.0,<5.0.0dev" websockets = ">=13.0,<15.0dev" [[package]] @@ -4555,13 +4556,13 @@ snowflake = ["snowflake-connector-python (>=3.3.0)", "snowflake-snowpark-python[ [[package]] name = "surya-ocr" -version = "0.11.0" +version = "0.11.1" description = "OCR, layout, reading order, and table recognition in 90+ languages" optional = false python-versions = "<4.0,>=3.10" files = [ - {file = "surya_ocr-0.11.0-py3-none-any.whl", hash = "sha256:2314a04d6aa2f362eefb14145b9d1b2c5b6568fb287ff8205cc0d580b9a304a3"}, - {file = "surya_ocr-0.11.0.tar.gz", hash = "sha256:c13475981929ad1a50e0151085815bbff183f9f328d2efba9b77c119e9ca754a"}, + {file = "surya_ocr-0.11.1-py3-none-any.whl", hash = "sha256:cdf7a40613d7109661999beb97db63355456b3119583f8850559bc20a4ac30e2"}, + {file = "surya_ocr-0.11.1.tar.gz", hash = "sha256:1de05f1b00d0a9c4c6e737b51d9192161f1dd48a3cf76437e56a33a390bd2d26"}, ] [package.dependencies] @@ -5450,4 +5451,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "d98a730ed15cb2a34a91a60062f5d6faa7eec256b2c42e79d868e5f0c9874c94" +content-hash = "b99e8e69eb72201880291be15253a25827fd226a99ed229a52ff1e0840f8482a" diff --git a/pyproject.toml b/pyproject.toml index 9d4cedf4..432b1019 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ torch = "^2.5.1" tqdm = "^4.66.1" ftfy = "^6.1.1" rapidfuzz = "^3.8.1" -surya-ocr = "~0.11.0" +surya-ocr = "~0.11.1" regex = "^2024.4.28" pdftext = "~0.5.1" markdownify = "^0.13.1" diff --git a/tests/builders/test_blank_page.py b/tests/builders/test_blank_page.py index 18b067c1..ba2073eb 100644 --- a/tests/builders/test_blank_page.py +++ b/tests/builders/test_blank_page.py @@ -2,10 +2,12 @@ from marker.builders.document import DocumentBuilder from marker.builders.layout import LayoutBuilder +from marker.builders.line import LineBuilder -def test_blank_page(config, pdf_provider, layout_model, ocr_error_model, recognition_model, detection_model): - layout_builder = LayoutBuilder(layout_model, ocr_error_model, config) +def test_blank_page(config, pdf_provider, layout_model, ocr_error_model, recognition_model, detection_model, inline_detection_model): + layout_builder = LayoutBuilder(layout_model, config) + line_builder = LineBuilder(detection_model, inline_detection_model, ocr_error_model) builder = DocumentBuilder(config) document = builder.build_document(pdf_provider) @@ -13,10 +15,11 @@ def test_blank_page(config, pdf_provider, layout_model, ocr_error_model, recogni bboxes=[], image_bbox=p.polygon.bbox, ) for p in document.pages] - page_lines = {p.page_id: [] for p in document.pages} + provider_lines = {p.page_id: [] for p in document.pages} + ocr_lines = {p.page_id: [] for p in document.pages} layout_builder.add_blocks_to_pages(document.pages, layout_results) - layout_builder.merge_blocks(document.pages, page_lines) + line_builder.merge_blocks(document, provider_lines, ocr_lines) assert all([isinstance(p.children, list) for p in document.pages]) assert all([isinstance(p.structure, list) for p in document.pages]) \ No newline at end of file diff --git a/tests/builders/test_garbled_pdf.py b/tests/builders/test_garbled_pdf.py index cb62bf4b..00fecdbf 100644 --- a/tests/builders/test_garbled_pdf.py +++ b/tests/builders/test_garbled_pdf.py @@ -1,7 +1,7 @@ import pytest from marker.builders.document import DocumentBuilder -from marker.builders.layout import LayoutBuilder +from marker.builders.line import LineBuilder from marker.processors.table import TableProcessor from marker.schema import BlockTypes @@ -17,10 +17,6 @@ def test_garbled_pdf(pdf_document, detection_model, recognition_model, table_rec assert table_cell.block_type == BlockTypes.Line assert table_cell.structure[0] == "/page/0/Span/2" - span = pdf_document.pages[0].contained_blocks(pdf_document, (BlockTypes.Span,))[0] - assert span.block_type == BlockTypes.Span - assert len(span.text.strip()) == 0 - # We don't OCR in the initial pass, only with the TableProcessor processor = TableProcessor(detection_model, recognition_model, table_rec_model) processor(pdf_document) @@ -28,27 +24,30 @@ def test_garbled_pdf(pdf_document, detection_model, recognition_model, table_rec table = pdf_document.pages[0].contained_blocks(pdf_document, (BlockTypes.Table,))[0] assert "варіант" in table.raw_text(pdf_document) + table_cell = pdf_document.pages[0].get_block(table_block.structure[0]) + assert table_cell.block_type == BlockTypes.TableCell + @pytest.mark.filename("hindi_judgement.pdf") -@pytest.mark.config({"page_range": [2, 3]}) -def test_garbled_builder(config, pdf_provider, layout_model, ocr_error_model): - layout_builder = LayoutBuilder(layout_model, ocr_error_model, config) +@pytest.mark.config({"page_range": [2, 3], "disable_ocr": True}) +def test_garbled_builder(config, pdf_provider, detection_model, inline_detection_model, ocr_error_model): + line_builder = LineBuilder(detection_model, inline_detection_model, ocr_error_model, config) builder = DocumentBuilder(config) document = builder.build_document(pdf_provider) - bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines) + bad_ocr_results = line_builder.ocr_error_detection(document.pages, pdf_provider.page_lines) assert len(bad_ocr_results.labels) == 2 assert any([l == "bad" for l in bad_ocr_results.labels]) @pytest.mark.filename("adversarial.pdf") -@pytest.mark.config({"page_range": [2, 3]}) -def test_nongarbled_builder(config, pdf_provider, layout_model, ocr_error_model): - layout_builder = LayoutBuilder(layout_model, ocr_error_model, config) +@pytest.mark.config({"page_range": [2, 3], "disable_ocr": True}) +def test_nongarbled_builder(config, pdf_provider, detection_model, inline_detection_model, ocr_error_model): + line_builder = LineBuilder(detection_model, inline_detection_model, ocr_error_model, config) builder = DocumentBuilder(config) document = builder.build_document(pdf_provider) - bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines) + bad_ocr_results = line_builder.ocr_error_detection(document.pages, pdf_provider.page_lines) assert len(bad_ocr_results.labels) == 2 assert all([l == "good" for l in bad_ocr_results.labels]) diff --git a/tests/builders/test_ocr_pipeline.py b/tests/builders/test_ocr_pipeline.py index 06f94d88..f079517e 100644 --- a/tests/builders/test_ocr_pipeline.py +++ b/tests/builders/test_ocr_pipeline.py @@ -4,8 +4,7 @@ from marker.schema.text.line import Line -@pytest.mark.config({"force_ocr": True, "page_range": [0]}) -def test_ocr_pipeline(pdf_document): +def _ocr_pipeline_test(pdf_document): first_page = pdf_document.pages[0] assert first_page.structure[0] == '/page/0/SectionHeader/0' @@ -23,11 +22,20 @@ def test_ocr_pipeline(pdf_document): # Ensure we match all text lines up properly # Makes sure the OCR bbox is being scaled to the same scale as the layout boxes text_lines = first_page.contained_blocks(pdf_document, (BlockTypes.Line,)) - text_blocks = first_page.contained_blocks(pdf_document, (BlockTypes.Text,BlockTypes.TextInlineMath)) - assert len(text_lines) == 75 + text_blocks = first_page.contained_blocks(pdf_document, (BlockTypes.Text, BlockTypes.TextInlineMath)) + assert len(text_lines) == 71 # Ensure the bbox sizes match up max_line_position = max([line.polygon.y_end for line in text_lines]) max_block_position = max([block.polygon.y_end for block in text_blocks if block.source == "layout"]) assert max_line_position <= (max_block_position * 1.02) + +@pytest.mark.config({"force_ocr": True, "page_range": [0]}) +def test_ocr_pipeline(pdf_document): + _ocr_pipeline_test(pdf_document) + +@pytest.mark.config({"force_ocr": True, "page_range": [0], "use_llm": True}) +def test_ocr_with_inline_pipeline(pdf_document): + _ocr_pipeline_test(pdf_document) + diff --git a/tests/builders/test_pdf_links.py b/tests/builders/test_pdf_links.py index 300a7579..58a0db55 100644 --- a/tests/builders/test_pdf_links.py +++ b/tests/builders/test_pdf_links.py @@ -6,13 +6,22 @@ from marker.renderers.markdown import MarkdownOutput from marker.schema import BlockTypes from marker.schema.document import Document +from marker.util import classes_to_strings @pytest.mark.filename("arxiv_test.pdf") @pytest.mark.output_format("markdown") -def test_pdf_links(pdf_document: Document, pdf_converter: PdfConverter, temp_pdf): +def test_pdf_links(pdf_document: Document, config, renderer, model_dict, temp_pdf): first_page = pdf_document.pages[1] + processors = ["marker.processors.reference.ReferenceProcessor"] + pdf_converter = PdfConverter( + artifact_dict=model_dict, + processor_list=processors, + renderer=classes_to_strings([renderer])[0], + config=config + ) + for section_header_span in first_page.contained_blocks(pdf_document, (BlockTypes.Span,)): if "II." in section_header_span.text: assert section_header_span.url == "#page-1-0" diff --git a/tests/conftest.py b/tests/conftest.py index 7d7f70ef..d297bf21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from marker.builders.document import DocumentBuilder from marker.builders.layout import LayoutBuilder +from marker.builders.line import LineBuilder from marker.builders.ocr import OcrBuilder from marker.converters.pdf import PdfConverter from marker.models import create_model_dict @@ -55,6 +56,9 @@ def table_rec_model(model_dict): def ocr_error_model(model_dict): yield model_dict["ocr_error_model"] +@pytest.fixture(scope="session") +def inline_detection_model(model_dict): + yield model_dict["inline_detection_model"] @pytest.fixture(scope="function") def config(request): @@ -67,17 +71,19 @@ def config(request): return config +@pytest.fixture(scope="session") +def pdf_dataset(): + return datasets.load_dataset("datalab-to/pdfs", split="train") @pytest.fixture(scope="function") -def temp_pdf(request): +def temp_pdf(request, pdf_dataset): filename_mark = request.node.get_closest_marker("filename") filename = filename_mark.args[0] if filename_mark else "adversarial.pdf" - dataset = datasets.load_dataset("datalab-to/pdfs", split="train") - idx = dataset['filename'].index(filename) + idx = pdf_dataset['filename'].index(filename) temp_pdf = tempfile.NamedTemporaryFile(suffix=".pdf") - temp_pdf.write(dataset['pdf'][idx]) + temp_pdf.write(pdf_dataset['pdf'][idx]) temp_pdf.flush() yield temp_pdf @@ -88,11 +94,12 @@ def pdf_provider(request, config, temp_pdf): @pytest.fixture(scope="function") -def pdf_document(request, config, pdf_provider, layout_model, ocr_error_model, recognition_model, detection_model): - layout_builder = LayoutBuilder(layout_model, ocr_error_model, config) - ocr_builder = OcrBuilder(detection_model, recognition_model, config) +def pdf_document(request, config, pdf_provider, layout_model, ocr_error_model, recognition_model, detection_model, inline_detection_model): + layout_builder = LayoutBuilder(layout_model, config) + line_builder = LineBuilder(detection_model, inline_detection_model, ocr_error_model, config) + ocr_builder = OcrBuilder(recognition_model, config) builder = DocumentBuilder(config) - document = builder(pdf_provider, layout_builder, ocr_builder) + document = builder(pdf_provider, layout_builder, line_builder, ocr_builder) yield document diff --git a/tests/converters/test_table_converter.py b/tests/converters/test_table_converter.py index f388d5a7..56caa726 100644 --- a/tests/converters/test_table_converter.py +++ b/tests/converters/test_table_converter.py @@ -25,6 +25,6 @@ def test_table_converter(config, model_dict, renderer, temp_pdf): @pytest.mark.output_format("markdown") @pytest.mark.config({"page_range": [5], "force_ocr": True}) -def test_table_converter(config, model_dict, renderer, temp_pdf): +def test_table_converter_ocr(config, model_dict, renderer, temp_pdf): _table_converter(config, model_dict, renderer, temp_pdf) diff --git a/tests/processors/test_inline_math.py b/tests/processors/test_inline_math.py new file mode 100644 index 00000000..306349d3 --- /dev/null +++ b/tests/processors/test_inline_math.py @@ -0,0 +1,48 @@ +from unittest.mock import Mock + +import pytest + +from marker.processors.llm.llm_meta import LLMSimpleBlockMetaProcessor +from marker.processors.llm.llm_text import LLMTextProcessor +from marker.schema import BlockTypes + + +@pytest.mark.filename("adversarial.pdf") +@pytest.mark.config({"page_range": [0], "use_llm": True}) +def test_llm_text_processor(pdf_document, mocker): + # Get all inline math lines + text_lines = pdf_document.contained_blocks((BlockTypes.Line,)) + text_lines = [line for line in text_lines if line.formats and "math" in line.formats] + assert len(text_lines) == 3 + corrected_lines = [""] * len(text_lines) + + mock_cls = Mock() + mock_cls.return_value.generate_response.return_value = {"corrected_lines": corrected_lines} + mocker.patch("marker.processors.llm.GoogleModel", mock_cls) + + config = {"use_llm": True, "google_api_key": "test"} + processor_lst = [LLMTextProcessor(config)] + processor = LLMSimpleBlockMetaProcessor(processor_lst, config) + processor(pdf_document) + + contained_spans = text_lines[0].contained_blocks(pdf_document, (BlockTypes.Span,)) + assert contained_spans[0].text == "Text\n" # Newline inserted at end of line + assert contained_spans[0].formats == ["math"] + + +@pytest.mark.filename("adversarial.pdf") +@pytest.mark.config({"page_range": [0]}) +def test_llm_text_processor_disabled(pdf_document): + # Get all inline math lines + text_lines = pdf_document.contained_blocks((BlockTypes.Line,)) + text_lines = [line for line in text_lines if line.formats and "math" in line.formats] + assert len(text_lines) == 0 + + +@pytest.mark.filename("adversarial.pdf") +@pytest.mark.config({"page_range": [0], "texify_inline_spans": True}) +def test_llm_text_processor_texify(pdf_document): + # Get all inline math lines + text_lines = pdf_document.contained_blocks((BlockTypes.Line,)) + text_lines = [line for line in text_lines if line.formats and "math" in line.formats] + assert len(text_lines) == 3 \ No newline at end of file diff --git a/tests/processors/test_llm_processors.py b/tests/processors/test_llm_processors.py index e7aacad1..d334ee3d 100644 --- a/tests/processors/test_llm_processors.py +++ b/tests/processors/test_llm_processors.py @@ -182,24 +182,4 @@ def test_multi_llm_processors(pdf_document, mocker): contained_equations = pdf_document.contained_blocks((BlockTypes.Equation,)) print([equation.html for equation in contained_equations]) - assert all(equation.html == description for equation in contained_equations) - -@pytest.mark.filename("adversarial.pdf") -@pytest.mark.config({"page_range": [0]}) -def test_llm_text_processor(pdf_document, mocker): - inline_math_block = pdf_document.contained_blocks((BlockTypes.TextInlineMath,))[0] - text_lines = inline_math_block.contained_blocks(pdf_document, (BlockTypes.Line,)) - corrected_lines = ["Text"] * len(text_lines) - - mock_cls = Mock() - mock_cls.return_value.generate_response.return_value = {"corrected_lines": corrected_lines} - mocker.patch("marker.processors.llm.GoogleModel", mock_cls) - - config = {"use_llm": True, "google_api_key": "test"} - processor_lst = [LLMTextProcessor(config)] - processor = LLMSimpleBlockMetaProcessor(processor_lst, config) - processor(pdf_document) - - contained_spans = text_lines[0].contained_blocks(pdf_document, (BlockTypes.Span,)) - assert contained_spans[0].text == "Text\n" # Newline inserted at end of line - assert contained_spans[0].formats == ["italic"] \ No newline at end of file + assert all(equation.html == description for equation in contained_equations) \ No newline at end of file