Skip to content

Commit

Permalink
v1 endpoints for equation extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
j2whiting committed Sep 4, 2024
1 parent b8f9e01 commit 73f0e29
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 0 deletions.
Empty file.
13 changes: 13 additions & 0 deletions document_intelligence/equation_to_latex/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
FROM python:3.10-slim

WORKDIR /app

COPY . /app

RUN pip install --no-cache-dir fastapi uvicorn pillow numpy opencv-python-headless onnxruntime paddleocr \

Check failure on line 7 in document_intelligence/equation_to_latex/Dockerfile

View workflow job for this annotation

GitHub Actions / verify / Lint Docker Files / Lint Docker Files

DL3013 warning: Pin versions in pip. Instead of `pip install <package>` use `pip install <package>==<version>` or `pip install --requirement <requirements file>`

Check failure on line 7 in document_intelligence/equation_to_latex/Dockerfile

View workflow job for this annotation

GitHub Actions / verify / Lint Docker Files / Lint Docker Files

DL3042 warning: Avoid use of cache directory with pip. Use `pip install --no-cache-dir <package>`

Check failure on line 7 in document_intelligence/equation_to_latex/Dockerfile

View workflow job for this annotation

GitHub Actions / Lint Docker Files / Lint Docker Files

DL3013 warning: Pin versions in pip. Instead of `pip install <package>` use `pip install <package>==<version>` or `pip install --requirement <requirements file>`

Check failure on line 7 in document_intelligence/equation_to_latex/Dockerfile

View workflow job for this annotation

GitHub Actions / Lint Docker Files / Lint Docker Files

DL3042 warning: Avoid use of cache directory with pip. Use `pip install --no-cache-dir <package>`
&& pip install git+https://github.com/j2whiting/texteller.git

EXPOSE 8001

# Run main.py when the container launches
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]
Empty file.
31 changes: 31 additions & 0 deletions document_intelligence/equation_to_latex/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Dict, List
import base64
from texteller.mixed_inference_model import MixedInferenceModel


app = FastAPI()

model = MixedInferenceModel()

class ImageData(BaseModel):
images: Dict[int, List[str]] # {page_num: [base64_images]}

@app.post("/predict_latex")
async def predict_latex(data: ImageData):
latex_results = {}

for page_num, base64_images in data.images.items():
latex_results[page_num] = []
for img_base64 in base64_images: # TODO: use batching
image_bytes = base64.b64decode(img_base64)
latex_result = model.predict(image_bytes)
latex_results[page_num].append(latex_result)

return JSONResponse(content=latex_results)

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)
13 changes: 13 additions & 0 deletions document_intelligence/nougat/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
FROM python:3.10-slim

WORKDIR /app

COPY . /app

RUN apt-get update && apt-get install -y libgl1-mesa-glx libglib2.0-0 && \

Check failure on line 7 in document_intelligence/nougat/Dockerfile

View workflow job for this annotation

GitHub Actions / verify / Lint Docker Files / Lint Docker Files

DL3008 warning: Pin versions in apt get install. Instead of `apt-get install <package>` use `apt-get install <package>=<version>`

Check failure on line 7 in document_intelligence/nougat/Dockerfile

View workflow job for this annotation

GitHub Actions / verify / Lint Docker Files / Lint Docker Files

DL3009 info: Delete the apt-get lists after installing something

Check failure on line 7 in document_intelligence/nougat/Dockerfile

View workflow job for this annotation

GitHub Actions / verify / Lint Docker Files / Lint Docker Files

DL3015 info: Avoid additional packages by specifying `--no-install-recommends`

Check failure on line 7 in document_intelligence/nougat/Dockerfile

View workflow job for this annotation

GitHub Actions / verify / Lint Docker Files / Lint Docker Files

DL3013 warning: Pin versions in pip. Instead of `pip install <package>` use `pip install <package>==<version>` or `pip install --requirement <requirements file>`

Check failure on line 7 in document_intelligence/nougat/Dockerfile

View workflow job for this annotation

GitHub Actions / Lint Docker Files / Lint Docker Files

DL3008 warning: Pin versions in apt get install. Instead of `apt-get install <package>` use `apt-get install <package>=<version>`

Check failure on line 7 in document_intelligence/nougat/Dockerfile

View workflow job for this annotation

GitHub Actions / Lint Docker Files / Lint Docker Files

DL3009 info: Delete the apt-get lists after installing something

Check failure on line 7 in document_intelligence/nougat/Dockerfile

View workflow job for this annotation

GitHub Actions / Lint Docker Files / Lint Docker Files

DL3015 info: Avoid additional packages by specifying `--no-install-recommends`

Check failure on line 7 in document_intelligence/nougat/Dockerfile

View workflow job for this annotation

GitHub Actions / Lint Docker Files / Lint Docker Files

DL3013 warning: Pin versions in pip. Instead of `pip install <package>` use `pip install <package>==<version>` or `pip install --requirement <requirements file>`
pip install --no-cache-dir fastapi uvicorn torch transformers pymupdf pillow

EXPOSE 8000

# Run main.py when the container launches
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
185 changes: 185 additions & 0 deletions document_intelligence/nougat/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import re
import logging
from collections import defaultdict
from typing import List, Tuple, Dict
from PIL import Image
import fitz
import torch
from transformers import (
VisionEncoderDecoderModel,
AutoProcessor,
StoppingCriteria,
StoppingCriteriaList
)
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NougatEquation Task")

class RunningVarTorch:
def __init__(self, L=15, norm=False):
self.values = None
self.L = L
self.norm = norm

def push(self, x: torch.Tensor):
assert x.dim() == 1
if self.values is None:
self.values = x[:, None]
elif self.values.shape[1] < self.L:
self.values = torch.cat((self.values, x[:, None]), 1)
else:
self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

def variance(self):
if self.values is None:
return
if self.norm:
return torch.var(self.values, 1) / self.values.shape[1]
else:
return torch.var(self.values, 1)


class StoppingCriteriaScores(StoppingCriteria):
def __init__(self, threshold: float = 0.015, window_size: int = 200):
super().__init__()
self.threshold = threshold
self.vars = RunningVarTorch(norm=True)
self.varvars = RunningVarTorch(L=window_size)
self.stop_inds = defaultdict(int)
self.stopped = defaultdict(bool)
self.size = 0
self.window_size = window_size

@torch.no_grad()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
last_scores = scores[-1]
self.vars.push(last_scores.max(1)[0].float().cpu())
self.varvars.push(self.vars.variance())
self.size += 1
if self.size < self.window_size:
return False

varvar = self.varvars.variance()
for b in range(len(last_scores)):
if varvar[b] < self.threshold:
if self.stop_inds[b] > 0 and not self.stopped[b]:
self.stopped[b] = self.stop_inds[b] >= self.size
else:
self.stop_inds[b] = int(
min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
)
else:
self.stop_inds[b] = 0
self.stopped[b] = False
return all(self.stopped.values()) and len(self.stopped) > 0


class ImageQueue:
def __init__(self, images: List[Image.Image], batch_size: int):
self.images = images
self.batch_size = batch_size
self.index = 0

def next_batch(self) -> Tuple[List[Image.Image], int]:
if self.index >= len(self.images):
return None, self.index

end_index = min(self.index + self.batch_size, len(self.images))
batch = self.images[self.index:end_index]

self.index += self.batch_size
return batch, self.index

def __len__(self):
return len(self.images)

class NougatEquationModel:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {self.device}. Loading model on CPU until needed.")
self.model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-small")
self.processor = AutoProcessor.from_pretrained("facebook/nougat-small")

def run_model(self, image_queue: ImageQueue):
logger.info("Running model")
logger.info(f"Using device: {self.device}")
self.model.to(self.device)

all_sequences = []

while True:
batch, index = image_queue.next_batch()
if batch is None:
break

pixel_values = self.processor(images=batch, return_tensors="pt").pixel_values
outputs = self.model.generate(
pixel_values.to(self.device),
min_length=1,
max_new_tokens=3584,
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
return_dict_in_generate=True,
output_scores=True,
)
sequences = self.processor.batch_decode(outputs['sequences'], skip_special_tokens=True)
sequences = self.processor.post_process_generation(sequences, fix_markdown=False)

all_sequences.extend(sequences)

if index >= len(image_queue.images):
break

logger.info("Model run complete")
logger.info("Returning model to CPU..")
self.model.to("cpu")

return all_sequences

@staticmethod
def extract_latex(text: str) -> List[str]:
inline_pattern = re.compile(r'\\\(.*?\\\)')
display_pattern = re.compile(r'\\\[.*?\\\]')
dollar_pattern = re.compile(r'\$\$.*?\$\$', re.DOTALL)

inline_latex = inline_pattern.findall(text)
display_latex = display_pattern.findall(text)
dollar_latex = dollar_pattern.findall(text)

all_latex = inline_latex + display_latex + dollar_latex

return list(set(all_latex))

def run_nougat(nougat_model: NougatEquationModel, images: List[Image.Image], batch_size: int) -> Dict[int, List[str]]:
image_queue = ImageQueue(images, batch_size=batch_size)
sequences = nougat_model.run_model(image_queue)
latex_dict = {i: NougatEquationModel.extract_latex(sequence) for i, sequence in enumerate(sequences)}
return latex_dict

app = FastAPI()

nougat_model = NougatEquationModel()

@app.post("/process_pdf")
async def process_pdf(file: UploadFile = File(...)):
pdf_bytes = await file.read()

doc = fitz.open(stream=pdf_bytes, filetype="pdf")
images = []
for i in range(len(doc)):
page = doc[i]
image = page.get_pixmap()
image = Image.frombytes("RGB", [image.width, image.height], image.samples)
images.append(image)

batch_size = 6
latex_dict = run_nougat(nougat_model, images, batch_size)

return JSONResponse(content=latex_dict)

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
11 changes: 11 additions & 0 deletions document_intelligence/pdf_to_equation/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM python:3.10-slim

WORKDIR /app

COPY . /app

RUN pip install --no-cache-dir fastapi uvicorn pillow numpy opencv-python-headless onnxruntime paddleocr texteller

Check failure on line 7 in document_intelligence/pdf_to_equation/Dockerfile

View workflow job for this annotation

GitHub Actions / verify / Lint Docker Files / Lint Docker Files

DL3013 warning: Pin versions in pip. Instead of `pip install <package>` use `pip install <package>==<version>` or `pip install --requirement <requirements file>`

Check failure on line 7 in document_intelligence/pdf_to_equation/Dockerfile

View workflow job for this annotation

GitHub Actions / Lint Docker Files / Lint Docker Files

DL3013 warning: Pin versions in pip. Instead of `pip install <package>` use `pip install <package>==<version>` or `pip install --requirement <requirements file>`

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
Empty file.
95 changes: 95 additions & 0 deletions document_intelligence/pdf_to_equation/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from pdf2image import convert_from_bytes
from PIL import Image
import numpy as np
from cnstd import LayoutAnalyzer
import base64
import io

app = FastAPI()

class PDFImageExtractor:
def __init__(self, pdf_bytes):
self.pdf_bytes = pdf_bytes
self.images = self._pdf_to_images()

def _pdf_to_images(self):
pages = convert_from_bytes(self.pdf_bytes)
return pages

def get_images(self):
return self.images

class ImageAnalyzer:
def __init__(self):
self.analyzer = LayoutAnalyzer('mfd')

def analyze_images(self, images):
all_detections = []
for image in images:
detections = self.analyze_image(image)
all_detections.append(detections)
return all_detections

def analyze_image(self, image):
return self.analyzer.analyze(image, resized_shape=1024)

def get_cropped_images(self, images, all_detections, isolated_only=True, padding: int = None):
cropped_images_dict = {}

for page_number, detections in enumerate(all_detections):
cropped_images = []
image_array = np.array(images[page_number])
for detection in detections:
if isolated_only and detection['type'] != 'isolated':
continue
box = detection['box']

x_coords = box[:, 0]
y_coords = box[:, 1]
x_min = int(np.min(x_coords))
x_max = int(np.max(x_coords))
y_min = int(np.min(y_coords))
y_max = int(np.max(y_coords))

if padding:
x_min = max(0, x_min - padding)
x_max = min(image_array.shape[1], x_max + padding)
y_min = max(0, y_min - padding)
y_max = min(image_array.shape[0], y_max + padding)

cropped_image = image_array[y_min:y_max, x_min:x_max]
cropped_image_pil = Image.fromarray(cropped_image)

cropped_images.append(cropped_image_pil)

cropped_images_dict[page_number] = cropped_images

return cropped_images_dict

@app.post("/process_pdf")
async def process_pdf(file: UploadFile = File(...), isolated_only: bool = True, padding: int = 10):
pdf_bytes = await file.read()
pdf_extractor = PDFImageExtractor(pdf_bytes)
images = pdf_extractor.get_images()

analyzer = ImageAnalyzer()
detections = analyzer.analyze_images(images)
cropped_images_dict = analyzer.get_cropped_images(images, detections, isolated_only, padding)

# Convert PIL images to base64 to send in the response
response_data = {}
for page_number, cropped_images in cropped_images_dict.items():
response_data[page_number] = []
for image in cropped_images:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
response_data[page_number].append(img_str)

return JSONResponse(content=response_data)

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

0 comments on commit 73f0e29

Please sign in to comment.