Skip to content

Commit

Permalink
Merge pull request #544 from VikParuchuri/vik/services
Browse files Browse the repository at this point in the history
Factor out llm services, enable local models
  • Loading branch information
VikParuchuri authored Feb 13, 2025
2 parents 4a385c2 + ed37c06 commit b4e8642
Show file tree
Hide file tree
Showing 23 changed files with 290 additions and 80 deletions.
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ See [below](#benchmarks) for detailed speed and accuracy benchmarks, and instruc

## Hybrid Mode

For the highest accuracy, pass the `--use_llm` flag to use an LLM alongside marker. This will do things like merge tables across pages, format tables properly, and extract values from forms. It uses `gemini-flash-2.0`, which is cheap and fast.
For the highest accuracy, pass the `--use_llm` flag to use an LLM alongside marker. This will do things like merge tables across pages, handle inline math, format tables properly, and extract values from forms. It can use any Google model (`gemini-2.0-flash` by default), or any ollama model. See [below](#llm-services) for details.

Here is a table benchmark comparing marker, gemini flash alone, and marker with use_llm:

Expand All @@ -42,7 +42,7 @@ As you can see, the use_llm mode offers higher accuracy than marker or gemini al

I want marker to be as widely accessible as possible, while still funding my development/training costs. Research and personal usage is always okay, but there are some restrictions on commercial usage.

The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. You also must not be competitive with the [Datalab API](https://www.datalab.to/). If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to).
The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under \$5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. You also must not be competitive with the [Datalab API](https://www.datalab.to/). If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to).

# Hosted API

Expand Down Expand Up @@ -105,6 +105,8 @@ Options:
- `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "en,fr,de"` for English, French, and German.
- `config --help`: List all available builders, processors, and converters, and their associated configuration. These values can be used to build a JSON configuration file for additional tweaking of marker defaults.
- `--converter_cls`: One of `marker.converters.pdf.PdfConverter` (default) or `marker.converters.table.TableConverter`. The `PdfConverter` will convert the whole PDF, the `TableConverter` will only extract and convert tables.
- `--llm_service`: Which llm service to use if `--use_llm` is passed. This defaults to `marker.services.gemini.GoogleGeminiService`.
- `--help`: see all of the flags that can be passed into marker. (it supports many more options then are listed above)

The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/recognition/languages.py). If you don't need OCR, marker can work with any language.

Expand Down Expand Up @@ -146,7 +148,7 @@ text, _, images = text_from_rendered(rendered)

### Custom configuration

You can pass configuration using the `ConfigParser`:
You can pass configuration using the `ConfigParser`. To see all available options, do `marker_single --help`.

```python
from marker.converters.pdf import PdfConverter
Expand Down Expand Up @@ -310,6 +312,16 @@ All output formats will return a metadata dictionary, with the following fields:
}
```

# LLM Services

When running with the `--use_llm` flag, you have a choice of services you can use:

- `Gemini` - this will use the Gemini developer API by default. You'll need to pass `--gemini_api_key` to configuration.
- `Google Vertex` - this will use vertex, which can be more reliable. You'll need to pass `--vertex_project_id`. To use it, set `--llm_service=marker.services.vertex.GoogleVertexService`.
- `Ollama` - this will use local models. You can configure `--ollama_base_url` and `--ollama_model`. To use it, set `--llm_service=marker.services.ollama.OllamaService`.

These services may have additional optional configuration as well - you can see it by viewing the classes.

# Internals

Marker is easy to extend. The core units of marker are:
Expand Down
10 changes: 5 additions & 5 deletions marker/builders/llm_layout.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Annotated
from typing import Annotated, Type

from surya.layout import LayoutPredictor
from tqdm import tqdm
from pydantic import BaseModel

from marker.builders.layout import LayoutBuilder
from marker.services.google import GoogleModel
from marker.services import BaseService
from marker.providers.pdf import PdfProvider
from marker.schema import BlockTypes
from marker.schema.blocks import Block
Expand Down Expand Up @@ -97,10 +97,10 @@ class LLMLayoutBuilder(LayoutBuilder):
Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
"""

def __init__(self, layout_model: LayoutPredictor, config=None):
def __init__(self, layout_model: LayoutPredictor, llm_service: BaseService, config=None):
super().__init__(layout_model, config)

self.model = GoogleModel(self.google_api_key, self.model_name)
self.llm_service = llm_service

def __call__(self, document: Document, provider: PdfProvider):
super().__call__(document, provider)
Expand Down Expand Up @@ -158,7 +158,7 @@ def process_block_complex_relabeling(self, document: Document, page: PageGroup,
def process_block_relabeling(self, document: Document, page: PageGroup, block: Block, prompt: str):
image = self.extract_image(document, block)

response = self.model.generate_response(
response = self.llm_service(
prompt,
image,
block,
Expand Down
3 changes: 2 additions & 1 deletion marker/config/crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from marker.processors import BaseProcessor
from marker.providers import BaseProvider
from marker.renderers import BaseRenderer
from marker.services import BaseService


class ConfigCrawler:
def __init__(self, base_classes=(BaseBuilder, BaseProcessor, BaseConverter, BaseProvider, BaseRenderer)):
def __init__(self, base_classes=(BaseBuilder, BaseProcessor, BaseConverter, BaseProvider, BaseRenderer, BaseService)):
self.base_classes = base_classes
self.class_config_map = {}

Expand Down
18 changes: 17 additions & 1 deletion marker/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def common_options(fn):
fn = click.option("--languages", type=str, default=None, help="Comma separated list of languages to use for OCR.")(fn)

# we put common options here
fn = click.option("--google_api_key", type=str, default=None, help="Google API key for using LLMs.")(fn)
fn = click.option("--use_llm", is_flag=True, default=False, help="Enable higher quality processing with LLMs.")(fn)
fn = click.option("--converter_cls", type=str, default=None, help="Converter class to use. Defaults to PDF converter.")(fn)
fn = click.option("--llm_service", type=str, default=None, help="LLM service to use - should be full import path, like marker.services.gemini.GoogleGeminiService")(fn)

# enum options
fn = click.option("--force_layout_block", type=click.Choice(choices=[t.name for t in BlockTypes]), default=None,)(fn)
Expand Down Expand Up @@ -74,8 +74,23 @@ def generate_config_dict(self) -> Dict[str, any]:
case _:
if k in crawler.attr_set:
config[k] = v

# Backward compatibility for google_api_key
if settings.GOOGLE_API_KEY:
config["gemini_api_key"] = settings.GOOGLE_API_KEY

return config

def get_llm_service(self):
# Only return an LLM service when use_llm is enabled
if not self.cli_options["use_llm"]:
return None

service_cls = self.cli_options["llm_service"]
if service_cls is None:
service_cls = "marker.services.gemini.GoogleGeminiService"
return service_cls

def get_renderer(self):
match self.cli_options["output_format"]:
case "json":
Expand Down Expand Up @@ -122,3 +137,4 @@ def get_output_folder(self, filepath: str):
def get_base_filename(self, filepath: str):
basename = os.path.basename(filepath)
return os.path.splitext(basename)[0]

4 changes: 3 additions & 1 deletion marker/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class BaseConverter:
def __init__(self, config: Optional[BaseModel | dict] = None):
assign_config(self, config)
self.config = config
self.llm_service = None

def __call__(self, *args, **kwargs):
raise NotImplementedError
Expand Down Expand Up @@ -52,7 +53,8 @@ def initialize_processors(self, processor_cls_lst: List[Type[BaseProcessor]]) ->

meta_processor = LLMSimpleBlockMetaProcessor(
processor_lst=simple_llm_processors,
config=self.config
llm_service=self.llm_service,
config=self.config,
)
other_processors.insert(insert_position, meta_processor)
return other_processors
20 changes: 19 additions & 1 deletion marker/converters/pdf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os

from marker.services.gemini import GoogleGeminiService

os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

import inspect
Expand Down Expand Up @@ -86,7 +89,14 @@ class PdfConverter(BaseConverter):
DebugProcessor,
)

def __init__(self, artifact_dict: Dict[str, Any], processor_list: Optional[List[str]] = None, renderer: str | None = None, config=None):
def __init__(
self,
artifact_dict: Dict[str, Any],
processor_list: Optional[List[str]] = None,
renderer: str | None = None,
llm_service: str | None = None,
config=None
):
super().__init__(config)

for block_type, override_block_type in self.override_map.items():
Expand All @@ -102,6 +112,14 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: Optional[List[
else:
renderer = MarkdownRenderer

if llm_service:
llm_service_cls = strings_to_classes([llm_service])[0]
llm_service = self.resolve_dependencies(llm_service_cls)

# Inject llm service into artifact_dict so it can be picked up by processors, etc.
artifact_dict["llm_service"] = llm_service
self.llm_service = llm_service

self.artifact_dict = artifact_dict
self.renderer = renderer

Expand Down
15 changes: 10 additions & 5 deletions marker/processors/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

from marker.processors import BaseProcessor
from marker.schema import BlockTypes
from marker.services.google import GoogleModel
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups import PageGroup
from marker.services import BaseService
from marker.settings import settings
from marker.util import assign_config


class PromptData(TypedDict):
Expand Down Expand Up @@ -67,14 +68,14 @@ class BaseLLMProcessor(BaseProcessor):
] = False
block_types = None

def __init__(self, config=None):
def __init__(self, llm_service: BaseService, config=None):
super().__init__(config)

self.model = None
self.llm_service = None
if not self.use_llm:
return

self.model = GoogleModel(self.google_api_key, self.model_name)
self.llm_service = llm_service

def extract_image(self, document: Document, image_block: Block, remove_blocks: Sequence[BlockTypes] | None = None) -> Image.Image:
return image_block.get_image(
Expand All @@ -90,7 +91,7 @@ class BaseLLMComplexBlockProcessor(BaseLLMProcessor):
A processor for using LLMs to convert blocks with more complex logic.
"""
def __call__(self, document: Document):
if not self.use_llm or self.model is None:
if not self.use_llm or self.llm_service is None:
return

try:
Expand Down Expand Up @@ -125,6 +126,10 @@ class BaseLLMSimpleBlockProcessor(BaseLLMProcessor):
A processor for using LLMs to convert single blocks.
"""

# Override init since we don't need an llmservice here
def __init__(self, config=None):
assign_config(self, config)

def __call__(self, result: dict, prompt_data: PromptData, document: Document):
try:
self.rewrite_block(result, prompt_data, document)
Expand Down
9 changes: 5 additions & 4 deletions marker/processors/llm/llm_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@

from marker.processors.llm import BaseLLMSimpleBlockProcessor, BaseLLMProcessor
from marker.schema.document import Document
from marker.services import BaseService


class LLMSimpleBlockMetaProcessor(BaseLLMProcessor):
"""
A wrapper for simple LLM processors, so they can all run in parallel.
"""
def __init__(self, processor_lst: List[BaseLLMSimpleBlockProcessor], config=None):
super().__init__(config)
def __init__(self, processor_lst: List[BaseLLMSimpleBlockProcessor], llm_service: BaseService, config=None):
super().__init__(llm_service, config)
self.processors = processor_lst

def __call__(self, document: Document):
if not self.use_llm or self.model is None:
if not self.use_llm or self.llm_service is None:
return

total = sum([len(processor.inference_blocks(document)) for processor in self.processors])
Expand Down Expand Up @@ -50,4 +51,4 @@ def __call__(self, document: Document):
pbar.close()

def get_response(self, prompt_data: Dict[str, Any]):
return self.model.generate_response(prompt_data["prompt"], prompt_data["image"], prompt_data["block"], prompt_data["schema"])
return self.llm_service(prompt_data["prompt"], prompt_data["image"], prompt_data["block"], prompt_data["schema"])
2 changes: 1 addition & 1 deletion marker/processors/llm/llm_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def process_rewriting(self, document: Document, page: PageGroup, block: Table):
def rewrite_single_chunk(self, page: PageGroup, block: Block, block_html: str, children: List[TableCell], image: Image.Image):
prompt = self.table_rewriting_prompt.replace("{block_html}", block_html)

response = self.model.generate_response(prompt, image, block, TableSchema)
response = self.llm_service(prompt, image, block, TableSchema)

if not response or "corrected_html" not in response:
block.update_metadata(llm_error_count=1)
Expand Down
2 changes: 1 addition & 1 deletion marker/processors/llm/llm_table_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def process_rewriting(self, document: Document, blocks: List[Block]):

prompt = self.table_merge_prompt.replace("{{table1}}", start_html).replace("{{table2}}", curr_html)

response = self.model.generate_response(
response = self.llm_service(
prompt,
[start_image, curr_image],
curr_block,
Expand Down
3 changes: 2 additions & 1 deletion marker/scripts/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def process_single_pdf(args):
config=config_parser.generate_config_dict(),
artifact_dict=model_refs,
processor_list=config_parser.get_processors(),
renderer=config_parser.get_renderer()
renderer=config_parser.get_renderer(),
llm_service=config_parser.get_llm_service()
)
rendered = converter(fpath)
out_folder = config_parser.get_output_folder(fpath)
Expand Down
3 changes: 2 additions & 1 deletion marker/scripts/convert_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def convert_single_cli(fpath: str, **kwargs):
config=config_parser.generate_config_dict(),
artifact_dict=models,
processor_list=config_parser.get_processors(),
renderer=config_parser.get_renderer()
renderer=config_parser.get_renderer(),
llm_service=config_parser.get_llm_service()
)
rendered = converter(fpath)
out_folder = config_parser.get_output_folder(fpath)
Expand Down
3 changes: 2 additions & 1 deletion marker/scripts/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ async def _convert_pdf(params: CommonParams):
config=config_dict,
artifact_dict=app_data["models"],
processor_list=config_parser.get_processors(),
renderer=config_parser.get_renderer()
renderer=config_parser.get_renderer(),
llm_service=config_parser.get_llm_service()
)
rendered = converter(params.filepath)
text, _, images = text_from_rendered(rendered)
Expand Down
3 changes: 2 additions & 1 deletion marker/scripts/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def convert_pdf(fname: str, config_parser: ConfigParser) -> (str, Dict[str, Any]
config=config_dict,
artifact_dict=model_dict,
processor_list=config_parser.get_processors(),
renderer=config_parser.get_renderer()
renderer=config_parser.get_renderer(),
llm_service=config_parser.get_llm_service()
)
return converter(fname)

Expand Down
26 changes: 26 additions & 0 deletions marker/services/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional, List

import PIL
from pydantic import BaseModel

from marker.schema.blocks import Block
from marker.util import assign_config, verify_config_keys


class BaseService:
def __init__(self, config: Optional[BaseModel | dict] = None):
assign_config(self, config)

# Ensure we have all necessary fields filled out (API keys, etc.)
verify_config_keys(self)

def __call__(
self,
prompt: str,
image: PIL.Image.Image | List[PIL.Image.Image],
block: Block,
response_schema: type[BaseModel],
max_retries: int = 1,
timeout: int = 15
):
raise NotImplementedError
Loading

0 comments on commit b4e8642

Please sign in to comment.