document_intelligence/equation_to_latex/Dockerfile
@@ -0,0 +1,13 @@
FROM python:3.10-slim

WORKDIR /app

COPY . /app

RUN pip install --no-cache-dir fastapi uvicorn pillow numpy opencv-python-headless onnxruntime paddleocr \
    && pip install git+https://github.com/j2whiting/texteller.git

EXPOSE 8001

# Run main.py when the container launches
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]
document_intelligence/equation_to_latex/main.py
@@ -0,0 +1,31 @@
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Dict, List
import base64
from texteller.mixed_inference_model import MixedInferenceModel


app = FastAPI()

model = MixedInferenceModel()

class ImageData(BaseModel):
    images: Dict[int, List[str]]  # {page_num: [base64_images]}

@app.post("/predict_latex")
async def predict_latex(data: ImageData):
    latex_results = {}

    for page_num, base64_images in data.images.items():
        latex_results[page_num] = []
        for img_base64 in base64_images:  # TODO: use batching
            image_bytes = base64.b64decode(img_base64)
            latex_result = model.predict(image_bytes)
            latex_results[page_num].append(latex_result)

    return JSONResponse(content=latex_results)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
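For reference, this service accepts a mapping from page number to a list of base64-encoded equation crops and returns LaTeX strings in the same shape. A minimal client sketch, assuming the container is reachable on localhost:8001 and that "equation.png" is an existing cropped equation image (both are placeholder assumptions, not part of the commit):

# Client sketch for /predict_latex; localhost:8001 and equation.png are assumptions.
import base64
import requests

with open("equation.png", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {"images": {0: [img_b64]}}  # {page_num: [base64_images]}
resp = requests.post("http://localhost:8001/predict_latex", json=payload)
print(resp.json())  # e.g. {"0": ["x^{2}+y^{2}=z^{2}"]}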
document_intelligence/nougat/Dockerfile
@@ -0,0 +1,13 @@
FROM python:3.10-slim

WORKDIR /app

COPY . /app

RUN apt-get update && apt-get install -y libgl1-mesa-glx libglib2.0-0 && \
    pip install --no-cache-dir fastapi uvicorn torch transformers pymupdf pillow

EXPOSE 8000

# Run main.py when the container launches
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
document_intelligence/nougat/main.py
@@ -0,0 +1,185 @@
import re
import logging
from collections import defaultdict
from typing import List, Tuple, Dict
from PIL import Image
import fitz
import torch
from transformers import (
    VisionEncoderDecoderModel,
    AutoProcessor,
    StoppingCriteria,
    StoppingCriteriaList
)
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NougatEquation Task")


class RunningVarTorch:
    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)


class StoppingCriteriaScores(StoppingCriteria):
    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0


class ImageQueue:
    def __init__(self, images: List[Image.Image], batch_size: int):
        self.images = images
        self.batch_size = batch_size
        self.index = 0

    def next_batch(self) -> Tuple[List[Image.Image], int]:
        if self.index >= len(self.images):
            return None, self.index

        end_index = min(self.index + self.batch_size, len(self.images))
        batch = self.images[self.index:end_index]

        self.index += self.batch_size
        return batch, self.index

    def __len__(self):
        return len(self.images)


class NougatEquationModel:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}. Loading model on CPU until needed.")
        self.model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-small")
        self.processor = AutoProcessor.from_pretrained("facebook/nougat-small")

    def run_model(self, image_queue: ImageQueue):
        logger.info("Running model")
        logger.info(f"Using device: {self.device}")
        self.model.to(self.device)

        all_sequences = []

        while True:
            batch, index = image_queue.next_batch()
            if batch is None:
                break

            pixel_values = self.processor(images=batch, return_tensors="pt").pixel_values
            outputs = self.model.generate(
                pixel_values.to(self.device),
                min_length=1,
                max_new_tokens=3584,
                bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
                return_dict_in_generate=True,
                output_scores=True,
            )
            sequences = self.processor.batch_decode(outputs['sequences'], skip_special_tokens=True)
            sequences = self.processor.post_process_generation(sequences, fix_markdown=False)

            all_sequences.extend(sequences)

            if index >= len(image_queue.images):
                break

        logger.info("Model run complete")
        logger.info("Returning model to CPU..")
        self.model.to("cpu")

        return all_sequences

    @staticmethod
    def extract_latex(text: str) -> List[str]:
        inline_pattern = re.compile(r'\\\(.*?\\\)')
        display_pattern = re.compile(r'\\\[.*?\\\]')
        dollar_pattern = re.compile(r'\$\$.*?\$\$', re.DOTALL)

        inline_latex = inline_pattern.findall(text)
        display_latex = display_pattern.findall(text)
        dollar_latex = dollar_pattern.findall(text)

        all_latex = inline_latex + display_latex + dollar_latex

        return list(set(all_latex))


def run_nougat(nougat_model: NougatEquationModel, images: List[Image.Image], batch_size: int) -> Dict[int, List[str]]:
    image_queue = ImageQueue(images, batch_size=batch_size)
    sequences = nougat_model.run_model(image_queue)
    latex_dict = {i: NougatEquationModel.extract_latex(sequence) for i, sequence in enumerate(sequences)}
    return latex_dict


app = FastAPI()

nougat_model = NougatEquationModel()

@app.post("/process_pdf")
async def process_pdf(file: UploadFile = File(...)):
    pdf_bytes = await file.read()

    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    images = []
    for i in range(len(doc)):
        page = doc[i]
        image = page.get_pixmap()
        image = Image.frombytes("RGB", [image.width, image.height], image.samples)
        images.append(image)

    batch_size = 6
    latex_dict = run_nougat(nougat_model, images, batch_size)

    return JSONResponse(content=latex_dict)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
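For reference, a minimal sketch of calling this service: it uploads a PDF and receives, per rendered page, the list of LaTeX snippets that extract_latex found in the Nougat output. The host, port, and "paper.pdf" path are assumptions, not part of the commit:

# Client sketch for the Nougat /process_pdf endpoint; localhost:8000 and paper.pdf are placeholders.
import requests

with open("paper.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/process_pdf",
        files={"file": ("paper.pdf", f, "application/pdf")},
    )

latex_by_page = resp.json()  # {"0": ["\\(x^{2}\\)", ...], "1": [...], ...}
for page, snippets in latex_by_page.items():
    print(page, snippets)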
document_intelligence/pdf_to_equation/Dockerfile
@@ -0,0 +1,11 @@
FROM python:3.10-slim

WORKDIR /app

COPY . /app

RUN pip install --no-cache-dir fastapi uvicorn pillow numpy opencv-python-headless onnxruntime paddleocr texteller

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
document_intelligence/pdf_to_equation/main.py
@@ -0,0 +1,95 @@
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from pdf2image import convert_from_bytes
from PIL import Image
import numpy as np
from cnstd import LayoutAnalyzer
import base64
import io

app = FastAPI()

class PDFImageExtractor:
    def __init__(self, pdf_bytes):
        self.pdf_bytes = pdf_bytes
        self.images = self._pdf_to_images()

    def _pdf_to_images(self):
        pages = convert_from_bytes(self.pdf_bytes)
        return pages

    def get_images(self):
        return self.images

class ImageAnalyzer:
    def __init__(self):
        self.analyzer = LayoutAnalyzer('mfd')

    def analyze_images(self, images):
        all_detections = []
        for image in images:
            detections = self.analyze_image(image)
            all_detections.append(detections)
        return all_detections

    def analyze_image(self, image):
        return self.analyzer.analyze(image, resized_shape=1024)

    def get_cropped_images(self, images, all_detections, isolated_only=True, padding: int = None):
        cropped_images_dict = {}

        for page_number, detections in enumerate(all_detections):
            cropped_images = []
            image_array = np.array(images[page_number])
            for detection in detections:
                if isolated_only and detection['type'] != 'isolated':
                    continue
                box = detection['box']

                x_coords = box[:, 0]
                y_coords = box[:, 1]
                x_min = int(np.min(x_coords))
                x_max = int(np.max(x_coords))
                y_min = int(np.min(y_coords))
                y_max = int(np.max(y_coords))

                if padding:
                    x_min = max(0, x_min - padding)
                    x_max = min(image_array.shape[1], x_max + padding)
                    y_min = max(0, y_min - padding)
                    y_max = min(image_array.shape[0], y_max + padding)

                cropped_image = image_array[y_min:y_max, x_min:x_max]
                cropped_image_pil = Image.fromarray(cropped_image)

                cropped_images.append(cropped_image_pil)

            cropped_images_dict[page_number] = cropped_images

        return cropped_images_dict

@app.post("/process_pdf")
async def process_pdf(file: UploadFile = File(...), isolated_only: bool = True, padding: int = 10):
    pdf_bytes = await file.read()
    pdf_extractor = PDFImageExtractor(pdf_bytes)
    images = pdf_extractor.get_images()

    analyzer = ImageAnalyzer()
    detections = analyzer.analyze_images(images)
    cropped_images_dict = analyzer.get_cropped_images(images, detections, isolated_only, padding)

    # Convert PIL images to base64 to send in the response
    response_data = {}
    for page_number, cropped_images in cropped_images_dict.items():
        response_data[page_number] = []
        for image in cropped_images:
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            response_data[page_number].append(img_str)

    return JSONResponse(content=response_data)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
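Taken together, this service and the equation_to_latex service form a two-step pipeline: crop equation regions out of a PDF here, then convert each crop to LaTeX there. A sketch of chaining the two over HTTP, assuming pdf_to_equation is reachable on port 8000 and equation_to_latex on port 8001 (hosts, ports, and "paper.pdf" are placeholder assumptions based on the EXPOSE/CMD lines above):

# Pipeline sketch: PDF -> equation crops -> LaTeX.
# localhost:8000 (pdf_to_equation), localhost:8001 (equation_to_latex), and paper.pdf are placeholders.
import requests

with open("paper.pdf", "rb") as f:
    crops = requests.post(
        "http://localhost:8000/process_pdf",
        files={"file": ("paper.pdf", f, "application/pdf")},
        params={"isolated_only": "true", "padding": 10},
    ).json()  # {page_num: [base64_png, ...]}

latex = requests.post(
    "http://localhost:8001/predict_latex",
    json={"images": crops},
).json()  # {page_num: [latex_string, ...]}
print(latex)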