Merge pull request #536 from VikParuchuri/vik/generator
Batch together llm inference requests
VikParuchuri authored Feb 11, 2025
2 parents 264ed41 + 31b2e2a commit 4f514c7
Showing 17 changed files with 389 additions and 122 deletions.
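This PR changes how block-level LLM calls are issued. Each BaseLLMSimpleBlockProcessor now only builds PromptData entries (block_prompts) and applies a parsed response (rewrite_block), while a new LLMSimpleBlockMetaProcessor, created in the converter's new initialize_processors, collects the prompts from all simple processors and dispatches them together. The meta-processor's implementation is not part of the hunks shown below, so the following is only a rough sketch of the batching pattern implied by the new interfaces; run_batched, max_workers, and the standalone-function form are illustrative, not the actual API.

from concurrent.futures import ThreadPoolExecutor, as_completed


def run_batched(simple_processors, model, document, max_workers=4):
    # Gather every prompt from every simple block processor up front.
    jobs = []
    for processor in simple_processors:
        for prompt_data in processor.block_prompts(document):
            jobs.append((processor, prompt_data))

    # Issue all LLM requests together instead of one block at a time.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                model.generate_response,
                prompt_data["prompt"],
                prompt_data["image"],
                prompt_data["block"],
                prompt_data["schema"],
            ): (processor, prompt_data)
            for processor, prompt_data in jobs
        }
        # As each response arrives, hand it back to the processor that owns the block.
        for future in as_completed(futures):
            processor, prompt_data = futures[future]
            processor(future.result(), prompt_data, document)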
2 changes: 1 addition & 1 deletion marker/builders/llm_layout.py
@@ -7,7 +7,7 @@
from pydantic import BaseModel

from marker.builders.layout import LayoutBuilder
from marker.processors.llm import GoogleModel
from marker.services.google import GoogleModel
from marker.providers.pdf import PdfProvider
from marker.schema import BlockTypes
from marker.schema.blocks import Block
48 changes: 46 additions & 2 deletions marker/converters/__init__.py
@@ -1,7 +1,11 @@
from typing import Optional
import inspect
from typing import Optional, List, Type

from pydantic import BaseModel

from marker.processors import BaseProcessor
from marker.processors.llm import BaseLLMSimpleBlockProcessor
from marker.processors.llm.llm_meta import LLMSimpleBlockMetaProcessor
from marker.util import assign_config


@@ -11,4 +15,44 @@ def __init__(self, config: Optional[BaseModel | dict] = None):
        self.config = config

    def __call__(self, *args, **kwargs):
        raise NotImplementedError
        raise NotImplementedError

    def resolve_dependencies(self, cls):
        init_signature = inspect.signature(cls.__init__)
        parameters = init_signature.parameters

        resolved_kwargs = {}
        for param_name, param in parameters.items():
            if param_name == 'self':
                continue
            elif param_name == 'config':
                resolved_kwargs[param_name] = self.config
            elif param.name in self.artifact_dict:
                resolved_kwargs[param_name] = self.artifact_dict[param_name]
            elif param.default != inspect.Parameter.empty:
                resolved_kwargs[param_name] = param.default
            else:
                raise ValueError(f"Cannot resolve dependency for parameter: {param_name}")

        return cls(**resolved_kwargs)

    def initialize_processors(self, processor_cls_lst: List[Type[BaseProcessor]]) -> List[BaseProcessor]:
        processors = []
        for processor_cls in processor_cls_lst:
            processors.append(self.resolve_dependencies(processor_cls))

        simple_llm_processors = [p for p in processors if issubclass(type(p), BaseLLMSimpleBlockProcessor)]
        other_processors = [p for p in processors if not issubclass(type(p), BaseLLMSimpleBlockProcessor)]

        if not simple_llm_processors:
            return processors

        llm_positions = [i for i, p in enumerate(processors) if issubclass(type(p), BaseLLMSimpleBlockProcessor)]
        insert_position = max(0, llm_positions[-1] - len(simple_llm_processors) + 1)

        meta_processor = LLMSimpleBlockMetaProcessor(
            processor_lst=simple_llm_processors,
            config=self.config
        )
        other_processors.insert(insert_position, meta_processor)
        return other_processors
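The insert_position arithmetic above keeps the batched meta-processor at roughly the spot where the last simple LLM processor previously ran, relative to the remaining processors. A toy walk-through with string stand-ins (hypothetical pipeline order, not real marker processor classes):

# Toy walk-through of the placement arithmetic above (string stand-ins only).
processors = ["ocr", "llm_a", "llm_b", "text", "llm_c", "render"]  # hypothetical pipeline order
llm_positions = [1, 2, 4]                              # indices of the simple LLM processors
insert_position = max(0, llm_positions[-1] - 3 + 1)    # max(0, 4 - 3 + 1) == 2
other_processors = ["ocr", "text", "render"]           # non-LLM processors keep their relative order
other_processors.insert(insert_position, "meta(llm_a, llm_b, llm_c)")
print(other_processors)  # ['ocr', 'text', 'meta(llm_a, llm_b, llm_c)', 'render']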
26 changes: 4 additions & 22 deletions marker/converters/pdf.py
@@ -102,32 +102,15 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: Optional[List[
            renderer = MarkdownRenderer

        self.artifact_dict = artifact_dict
        self.processor_list = processor_list
        self.renderer = renderer

        processor_list = self.initialize_processors(processor_list)
        self.processor_list = processor_list

        self.layout_builder_class = LayoutBuilder
        if self.use_llm:
            self.layout_builder_class = LLMLayoutBuilder

    def resolve_dependencies(self, cls):
        init_signature = inspect.signature(cls.__init__)
        parameters = init_signature.parameters

        resolved_kwargs = {}
        for param_name, param in parameters.items():
            if param_name == 'self':
                continue
            elif param_name == 'config':
                resolved_kwargs[param_name] = self.config
            elif param.name in self.artifact_dict:
                resolved_kwargs[param_name] = self.artifact_dict[param_name]
            elif param.default != inspect.Parameter.empty:
                resolved_kwargs[param_name] = param.default
            else:
                raise ValueError(f"Cannot resolve dependency for parameter: {param_name}")

        return cls(**resolved_kwargs)

    @cache
    def build_document(self, filepath: str):
        provider_cls = provider_from_filepath(filepath)
@@ -137,8 +120,7 @@ def build_document(self, filepath: str):
        document = DocumentBuilder(self.config)(provider, layout_builder, ocr_builder)
        StructureBuilder(self.config)(document)

        for processor_cls in self.processor_list:
            processor = self.resolve_dependencies(processor_cls)
        for processor in self.processor_list:
            processor(document)

        return document
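With this change the processor list is resolved once in the constructor rather than on every build_document call, so the same processor (and meta-processor) instances are reused across conversions. A minimal usage sketch, assuming marker's create_model_dict helper and a placeholder example.pdf path:

from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict

converter = PdfConverter(artifact_dict=create_model_dict())  # processors are instantiated here
rendered = converter("example.pdf")  # builders run, then the already-initialized processor list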
3 changes: 1 addition & 2 deletions marker/converters/table.py
@@ -37,8 +37,7 @@ def build_document(self, filepath: str):
        for page in document.pages:
            page.structure = [p for p in page.structure if p.block_type in self.converter_block_types]

        for processor_cls in self.processor_list:
            processor = self.resolve_dependencies(processor_cls)
        for processor in self.processor_list:
            processor(document)

        return document
61 changes: 56 additions & 5 deletions marker/processors/llm/__init__.py
@@ -1,17 +1,32 @@
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Annotated, Optional
from typing import Annotated, TypedDict, List

from pydantic import BaseModel
from tqdm import tqdm
from PIL import Image

from marker.processors import BaseProcessor
from marker.processors.llm.utils import GoogleModel
from marker.services.google import GoogleModel
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups import PageGroup
from marker.settings import settings


class PromptData(TypedDict):
    prompt: str
    image: Image.Image
    block: Block
    schema: BaseModel
    page: PageGroup


class BlockData(TypedDict):
    page: PageGroup
    block: Block


class BaseLLMProcessor(BaseProcessor):
    """
    A processor for using LLMs to convert blocks.
@@ -35,7 +50,7 @@ class BaseLLMProcessor(BaseProcessor):
    timeout: Annotated[
        int,
        "The timeout for requests to the Gemini model.",
    ] = 15
    ] = 20
    image_expansion_ratio: Annotated[
        float,
        "The ratio to expand the image by when cropping.",
@@ -59,6 +74,14 @@ def __init__(self, config=None):

        self.model = GoogleModel(self.google_api_key, self.model_name)

    def extract_image(self, document: Document, image_block: Block):
        return image_block.get_image(document, highres=True, expansion=(self.image_expansion_ratio, self.image_expansion_ratio))


class BaseLLMComplexBlockProcessor(BaseLLMProcessor):
    """
    A processor for using LLMs to convert blocks with more complex logic.
    """
    def __call__(self, document: Document):
        if not self.use_llm or self.model is None:
            return
@@ -89,5 +112,33 @@ def rewrite_blocks(self, document: Document):

        pbar.close()

    def extract_image(self, document: Document, image_block: Block):
        return image_block.get_image(document, highres=True, expansion=(self.image_expansion_ratio, self.image_expansion_ratio))

class BaseLLMSimpleBlockProcessor(BaseLLMProcessor):
    """
    A processor for using LLMs to convert single blocks.
    """

    def __call__(self, result: dict, prompt_data: PromptData, document: Document):
        try:
            self.rewrite_block(result, prompt_data, document)
        except Exception as e:
            print(f"Error rewriting block in {self.__class__.__name__}: {e}")
            traceback.print_exc()

    def inference_blocks(self, document: Document) -> List[BlockData]:
        blocks = []
        for page in document.pages:
            for block in page.contained_blocks(document, self.block_types):
                blocks.append({
                    "page": page,
                    "block": block
                })
        return blocks

    def block_prompts(self, document: Document) -> List[PromptData]:
        raise NotImplementedError()

    def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
        raise NotImplementedError()
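For reference, a minimal sketch of what a subclass of BaseLLMSimpleBlockProcessor now looks like; DemoCaptionProcessor and DemoSchema are hypothetical, and the real implementations in this PR (LLMComplexRegionProcessor and LLMEquationProcessor below) follow the same shape:

from typing import List

from pydantic import BaseModel

from marker.processors.llm import BaseLLMSimpleBlockProcessor, PromptData
from marker.schema import BlockTypes
from marker.schema.document import Document


class DemoSchema(BaseModel):
    corrected_text: str


class DemoCaptionProcessor(BaseLLMSimpleBlockProcessor):
    block_types = (BlockTypes.Caption,)

    def block_prompts(self, document: Document) -> List[PromptData]:
        # One PromptData per matching block; the meta-processor batches the actual LLM calls.
        prompt_data = []
        for block_data in self.inference_blocks(document):
            block = block_data["block"]
            prompt_data.append({
                "prompt": f"Correct this caption: {block.raw_text(document)}",
                "image": self.extract_image(document, block),
                "block": block,
                "schema": DemoSchema,
                "page": block_data["page"],
            })
        return prompt_data

    def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
        # Called with the already-parsed response for the block this prompt came from.
        if not response or "corrected_text" not in response:
            prompt_data["block"].update_metadata(llm_error_count=1)
            return
        # A real processor would write response["corrected_text"] back onto the block here.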


30 changes: 21 additions & 9 deletions marker/processors/llm/llm_complex.py
@@ -1,15 +1,15 @@
from typing import List

import markdown2
from pydantic import BaseModel

from marker.processors.llm import BaseLLMProcessor
from marker.processors.llm import PromptData, BaseLLMSimpleBlockProcessor

from marker.schema import BlockTypes
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups.page import PageGroup


class LLMComplexRegionProcessor(BaseLLMProcessor):
class LLMComplexRegionProcessor(BaseLLMSimpleBlockProcessor):
    block_types = (BlockTypes.ComplexRegion,)
    complex_region_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
You will receive an image of a text block and the text that can be extracted from the image.
@@ -50,12 +50,24 @@ class LLMComplexRegionProcessor(BaseLLMProcessor):
```
"""

    def process_rewriting(self, document: Document, page: PageGroup, block: Block):
        text = block.raw_text(document)
        prompt = self.complex_region_prompt.replace("{extracted_text}", text)
        image = self.extract_image(document, block)
    def block_prompts(self, document: Document) -> List[PromptData]:
        prompt_data = []
        for block in self.inference_blocks(document):
            text = block["block"].raw_text(document)
            prompt = self.complex_region_prompt.replace("{extracted_text}", text)
            image = self.extract_image(document, block["block"])
            prompt_data.append({
                "prompt": prompt,
                "image": image,
                "block": block["block"],
                "schema": ComplexSchema,
                "page": block["page"]
            })
        return prompt_data

        response = self.model.generate_response(prompt, image, block, ComplexSchema)
    def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
        block = prompt_data["block"]
        text = block.raw_text(document)

        if not response or "corrected_markdown" not in response:
            block.update_metadata(llm_error_count=1)
47 changes: 35 additions & 12 deletions marker/processors/llm/llm_equation.py
@@ -1,22 +1,19 @@
from pydantic import BaseModel

from marker.processors.llm import BaseLLMProcessor

from marker.processors.llm import BaseLLMSimpleBlockProcessor, PromptData, BlockData
from marker.schema import BlockTypes
from marker.schema.blocks import Equation
from marker.schema.document import Document
from marker.schema.groups.page import PageGroup

from typing import Annotated
from typing import Annotated, List


class LLMEquationProcessor(BaseLLMProcessor):
class LLMEquationProcessor(BaseLLMSimpleBlockProcessor):
    block_types = (BlockTypes.Equation,)
    min_equation_height: Annotated[
        float,
        "The minimum ratio between equation height and page height to consider for processing.",
    ] = 0.08
    equation_image_expansion_ratio: Annotated[
    image_expansion_ratio: Annotated[
        float,
        "The ratio to expand the image by when cropping.",
    ] = 0.05 # Equations sometimes get bboxes that are too tight
@@ -62,13 +59,39 @@ class LLMEquationProcessor(BaseLLMProcessor):
```
"""

    def process_rewriting(self, document: Document, page: PageGroup, block: Equation):
        text = block.html if block.html else block.raw_text(document)
        prompt = self.equation_latex_prompt.replace("{equation}", text)
    def inference_blocks(self, document: Document) -> List[BlockData]:
        blocks = super().inference_blocks(document)
        out_blocks = []
        for block_data in blocks:
            block = block_data["block"]
            page = block_data["page"]
            if block.polygon.height / page.polygon.height < self.min_equation_height:
                continue
            out_blocks.append(block_data)
        return out_blocks

    def block_prompts(self, document: Document) -> List[PromptData]:
        prompt_data = []
        for block_data in self.inference_blocks(document):
            block = block_data["block"]
            text = block.html if block.html else block.raw_text(document)
            prompt = self.equation_latex_prompt.replace("{equation}", text)
            image = self.extract_image(document, block)

        image = self.extract_image(document, block)
            prompt_data.append({
                "prompt": prompt,
                "image": image,
                "block": block,
                "schema": EquationSchema,
                "page": block_data["page"]
            })

        response = self.model.generate_response(prompt, image, block, EquationSchema)
        return prompt_data


    def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
        block = prompt_data["block"]
        text = block.html if block.html else block.raw_text(document)

        if not response or "html_equation" not in response:
            block.update_metadata(llm_error_count=1)