Skip to content

Commit

Permalink
Merge pull request #235 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Performance improvements
  • Loading branch information
iammosespaulr authored Oct 30, 2024
2 parents 8af02d0 + 015bc31 commit a8b34c4
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 57 deletions.
2 changes: 1 addition & 1 deletion detect_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def main():
start = time.time()
line_predictions = batch_text_detection(images, det_model, det_processor)

layout_predictions = batch_layout_detection(images, model, processor, line_predictions)
layout_predictions = batch_layout_detection(images, model, processor, line_predictions, include_maps=args.debug)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
if args.debug:
Expand Down
2 changes: 1 addition & 1 deletion detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main():
folder_name = os.path.basename(args.input_path).split(".")[0]

start = time.time()
predictions = batch_text_detection(images, model, processor)
predictions = batch_text_detection(images, model, processor, include_maps=args.debug)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
end = time.time()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.6.12"
version = "0.6.13"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
readme = "README.md"
Expand Down
23 changes: 11 additions & 12 deletions surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from surya.schema import TextDetectionResult
from surya.settings import settings
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
import torch.nn.functional as F

from surya.util.parallel import FakeParallel
from surya.util.parallel import FakeExecutor


def get_batch_size():
Expand Down Expand Up @@ -107,10 +107,12 @@ def batch_detection(
yield preds, [orig_sizes[j] for j in batch_image_idxs]


def parallel_get_lines(preds, orig_sizes):
def parallel_get_lines(preds, orig_sizes, include_maps=False):
heatmap, affinity_map = preds
heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
heat_img, aff_img = None, None
if include_maps:
heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
affinity_size = list(reversed(affinity_map.shape))
heatmap_size = list(reversed(heatmap.shape))
bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
Expand All @@ -126,19 +128,16 @@ def parallel_get_lines(preds, orig_sizes):
return result


def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
def batch_text_detection(images: List, model, processor, batch_size=None, include_maps=False) -> List[TextDetectionResult]:
detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

postprocessing_futures = []
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

with ProcessPoolExecutor(
max_workers=max_workers,
) if parallelize else contextlib.nullcontext() as executor:
func = executor.submit if parallelize else FakeParallel
executor = ThreadPoolExecutor if parallelize else FakeExecutor
with executor(max_workers=max_workers) as e:
for preds, orig_sizes in detection_generator:
for pred, orig_size in zip(preds, orig_sizes):
postprocessing_futures.append(func(parallel_get_lines, pred, orig_size))
postprocessing_futures.append(e.submit(parallel_get_lines, pred, orig_size, include_maps))

return [future.result() for future in postprocessing_futures]
26 changes: 13 additions & 13 deletions surya/layout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from PIL import Image
import numpy as np
Expand All @@ -9,7 +9,7 @@
from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
from surya.settings import settings
from surya.util.parallel import FakeParallel
from surya.util.parallel import FakeExecutor


def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
Expand Down Expand Up @@ -167,7 +167,7 @@ def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignm
return bboxes


def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None, include_maps=False) -> LayoutResult:
logits = np.stack(heatmaps, axis=0)
segment_assignment = logits.argmax(axis=0)
if detection_results is not None:
Expand All @@ -176,39 +176,39 @@ def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detect
else:
bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment)

segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8))
segmentation_img = None
if include_maps:
segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8))

result = LayoutResult(
bboxes=bboxes,
segmentation_map=segmentation_img,
heatmaps=heatmaps,
heatmaps=heatmaps if include_maps else None,
image_bbox=[0, 0, orig_size[0], orig_size[1]]
)

return result


def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None, include_maps=False) -> List[LayoutResult]:
layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
id2label = model.config.id2label

max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
postprocessing_futures = []
with ProcessPoolExecutor(
max_workers=max_workers,
) if parallelize else contextlib.nullcontext() as executor:
executor = ThreadPoolExecutor if parallelize else FakeExecutor
with executor(max_workers=max_workers) as e:
img_idx = 0
func = executor.submit if parallelize else FakeParallel

for preds, orig_sizes in layout_generator:
for pred, orig_size in zip(preds, orig_sizes):
future = func(
future = e.submit(
parallel_get_regions,
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
detection_results[img_idx] if detection_results else None,
include_maps
)

postprocessing_futures.append(future)
Expand Down
39 changes: 16 additions & 23 deletions surya/postprocessing/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def detect_boxes(linemap, text_threshold, low_text):
ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer)

mask = (labels[sy:ey, sx:ex] == k)
selected_linemap = linemap[sy:ey, sx:ex][mask]
line_max = np.max(selected_linemap)
line_max = np.max(linemap[sy:ey, sx:ex][mask])

# thresholding
if line_max < text_threshold:
Expand All @@ -115,13 +114,13 @@ def detect_boxes(linemap, text_threshold, low_text):
segmap = mask.astype(np.uint8)

ksize = buffer + niter
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(ksize, ksize))
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (ksize, ksize))
selected_segmap = cv2.dilate(segmap, kernel)

# make box
indices = np.nonzero(selected_segmap)
x_inds = indices[1] + sx
y_inds = indices[0] + sy
y_inds, x_inds = np.nonzero(selected_segmap)
x_inds += sx
y_inds += sy
np_contours = np.column_stack((x_inds, y_inds))
rectangle = cv2.minAreaRect(np_contours)
box = cv2.boxPoints(rectangle)
Expand All @@ -130,39 +129,36 @@ def detect_boxes(linemap, text_threshold, low_text):
w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
box_ratio = max(w, h) / (min(w, h) + 1e-5)
if abs(1 - box_ratio) <= 0.1:
l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
l, r = np_contours[:, 0].min(), np_contours[:, 0].max()
t, b = np_contours[:, 1].min(), np_contours[:, 1].max()
box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

# make clock-wise order
startidx = box.sum(axis=1).argmin()
box = np.roll(box, 4-startidx, 0)
box = np.array(box)
box = np.roll(box, 4 - startidx, 0)

confidence = line_max
max_confidence = max(max_confidence, line_max)

confidences.append(confidence)
confidences.append(line_max)
det.append(box)

if max_confidence > 0:
confidences = [c / max_confidence for c in confidences]
return det, confidences


def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:
def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:
if text_threshold is None:
text_threshold = settings.DETECTOR_TEXT_THRESHOLD

if low_text is None:
low_text = settings.DETECTOR_BLANK_THRESHOLD

textmap = textmap.copy()
textmap = textmap.astype(np.float32)
if textmap.dtype != np.float32:
textmap = textmap.astype(np.float32)

boxes, confidences = detect_boxes(textmap, text_threshold, low_text)
# From point form to box form
boxes = [PolygonBox(polygon=box, confidence=confidence) for box, confidence in zip(boxes, confidences)]
return boxes
return [PolygonBox(polygon=box, confidence=confidence) for box, confidence in zip(boxes, confidences)]


def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None, low_text=None) -> List[PolygonBox]:
Expand All @@ -175,8 +171,7 @@ def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None
return bboxes



def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list='red'):
def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list = 'red'):
polys = []
for bb in bboxes:
# Clockwise polygon
Expand All @@ -191,7 +186,7 @@ def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color:
return draw_polys_on_image(polys, image, labels, label_font_size=label_font_size, color=color)


def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list='red'):
def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list = 'red'):
draw = ImageDraw.Draw(image)
font_path = get_font_path()
label_font = ImageFont.truetype(font_path, label_font_size)
Expand Down Expand Up @@ -223,5 +218,3 @@ def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offse
)

return image


7 changes: 4 additions & 3 deletions surya/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,15 @@ class OCRResult(BaseModel):
class TextDetectionResult(BaseModel):
bboxes: List[PolygonBox]
vertical_lines: List[ColumnLine]
heatmap: Any
affinity_map: Any
heatmap: Optional[Any]
affinity_map: Optional[Any]
image_bbox: List[float]


class LayoutResult(BaseModel):
bboxes: List[LayoutBox]
segmentation_map: Any
segmentation_map: Optional[Any]
heatmaps: Optional[Any]
image_bbox: List[float]


Expand Down
19 changes: 16 additions & 3 deletions surya/util/parallel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
class FakeParallel():
def __init__(self, func, *args):
self._result = func(*args)
class FakeFuture:
    """Synchronous stand-in for ``concurrent.futures.Future``.

    The wrapped callable is executed immediately, in the calling thread,
    when the object is constructed. Its outcome (return value or raised
    exception) is stored and delivered by ``result()``.
    """

    def __init__(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` right away and record the outcome.

        An ``Exception`` raised by ``func`` is captured instead of
        propagating, so that it surfaces from ``result()`` the same way it
        does with a real future. Non-``Exception`` interrupts such as
        ``KeyboardInterrupt`` are deliberately not captured.
        """
        self._result = None
        self._exception = None
        try:
            self._result = func(*args, **kwargs)
        except Exception as exc:
            self._exception = exc

    def result(self, timeout=None):
        """Return the stored value, or re-raise the stored exception.

        ``timeout`` is accepted only for signature compatibility with
        ``concurrent.futures.Future.result``; the work has already finished
        by the time this object exists, so the argument is ignored.
        """
        if self._exception is not None:
            raise self._exception
        return self._result

class FakeExecutor:
    """In-process replacement for a ``concurrent.futures`` executor.

    Offers the small part of the executor interface the callers rely on
    (construction, use as a context manager, and ``submit``) but performs
    every submitted call synchronously in the calling thread.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any executor configuration.

        Positional and keyword arguments (for example ``max_workers``) are
        both allowed so the class can be constructed exactly like a real
        executor; none of them has any effect here.
        """
        pass

    def __enter__(self):
        """Return the executor itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *excinfo):
        """Do nothing on exit; there are no workers to shut down.

        Implicitly returns ``None``, so an exception raised inside the
        ``with`` body is not suppressed.
        """
        pass

    def submit(self, fn, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` now and return a ``FakeFuture`` for it."""
        return FakeFuture(fn, *args, **kwargs)

0 comments on commit a8b34c4

Please sign in to comment.