diff --git a/pyproject.toml b/pyproject.toml
index a1b69e3..18a9523 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.6.8"
+version = "0.6.9"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
 readme = "README.md"
diff --git a/surya/detection.py b/surya/detection.py
index bef9d06..b51a5e5 100644
--- a/surya/detection.py
+++ b/surya/detection.py
@@ -1,3 +1,5 @@
+import contextlib
+import multiprocessing
 import threading
 from queue import Queue
 from typing import List, Tuple, Generator
@@ -16,6 +18,8 @@
 from concurrent.futures import ProcessPoolExecutor
 import torch.nn.functional as F
 
+from surya.util.parallel import FakeParallel
+
 
 def get_batch_size():
     batch_size = settings.DETECTOR_BATCH_SIZE
@@ -127,18 +131,52 @@ def parallel_get_lines(preds, orig_sizes):
 
 def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
     detection_generator = batch_detection(images, model, processor, batch_size=batch_size)
-    results = []
+    postprocessing_futures = []
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
-
-    if parallelize:
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
-            for preds, orig_sizes in detection_generator:
-                batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
-                results.extend(batch_results)
-    else:
-        for preds, orig_sizes in detection_generator:
-            for pred, orig_size in zip(preds, orig_sizes):
-                results.append(parallel_get_lines(pred, orig_size))
+    batch_queue = Queue()
+    processing_error = threading.Event()
+
+    def inference_producer():
+        try:
+            for batch in detection_generator:
+                batch_queue.put(batch)
+                if processing_error.is_set():
+                    break
+        except Exception as e:
+            processing_error.set()
+            print("Error with batch detection", e)
+        finally:
+            batch_queue.put(None)  # Signal end of batches
+
+    def postprocessing_consumer(executor):
+        while not processing_error.is_set():
+            batch = batch_queue.get()
+            if batch is None:
+                break
+
+            try:
+                preds, orig_sizes = batch
+                func = executor.submit if parallelize else FakeParallel
+                for pred, orig_size in zip(preds, orig_sizes):
+                    postprocessing_futures.append(func(parallel_get_lines, pred, orig_size))
+            except Exception as e:
+                processing_error.set()
+                print("Error with postprocessing", e)
+
+    # Start producer and consumer threads
+    producer = threading.Thread(target=inference_producer, daemon=True)
+    producer.start()
+
+    with ProcessPoolExecutor(
+        max_workers=max_workers,
+        mp_context=multiprocessing.get_context("spawn")
+    ) if parallelize else contextlib.nullcontext() as executor:
+        consumer = threading.Thread(target=postprocessing_consumer, args=(executor,), daemon=True)
+        consumer.start()
+        producer.join()
+        consumer.join()
+
+    results = [future.result() for future in postprocessing_futures]
 
     return results
\ No newline at end of file
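[Review note, not part of the patch: the detection.py rework above replaces "run all inference, then postprocess" with a producer/consumer pipeline. A background thread feeds inference batches into a queue while a consumer drains it and dispatches postprocessing, so GPU inference no longer waits on CPU postprocessing. Below is a minimal standalone sketch of that pattern; work() and run() are hypothetical stand-ins for parallel_get_lines and batch_text_detection, the local FakeParallel copy mirrors the helper added in surya/util/parallel.py, and the consumer runs inline here rather than on its own thread as in the patch, purely to keep the sketch short.]

    import contextlib
    import threading
    from concurrent.futures import ProcessPoolExecutor
    from queue import Queue


    def work(item):
        # Stand-in for parallel_get_lines; module-level so the pool can pickle it.
        return item * 2


    class FakeParallel():
        # Mirrors surya/util/parallel.py: runs eagerly, answers result() like a Future.
        def __init__(self, func, *args):
            self._result = func(*args)

        def result(self):
            return self._result


    def run(batches, parallelize):
        queue, futures = Queue(), []

        def producer():
            for batch in batches:  # stand-in for the inference generator
                queue.put(batch)
            queue.put(None)  # sentinel: no more batches

        def consumer(executor):
            while True:
                batch = queue.get()
                if batch is None:
                    break
                # Single code path for both modes: executor.submit returns a real
                # Future; FakeParallel computes inline but exposes the same result().
                func = executor.submit if parallelize else FakeParallel
                futures.append(func(work, batch))

        threading.Thread(target=producer, daemon=True).start()
        with ProcessPoolExecutor() if parallelize else contextlib.nullcontext() as executor:
            consumer(executor)
        return [f.result() for f in futures]


    if __name__ == "__main__":  # guard is required once a process pool spawns workers
        print(run([1, 2, 3], parallelize=True))  # -> [2, 4, 6]

[The patch also pins the pool to multiprocessing.get_context("spawn"). My reading is that this avoids fork()-ing workers from a process that may already hold CUDA state, a known source of hangs; the diff itself does not state the motivation.]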
diff --git a/surya/layout.py b/surya/layout.py
index d488b97..eeb433d 100644
--- a/surya/layout.py
+++ b/surya/layout.py
@@ -1,3 +1,5 @@
+import contextlib
+import multiprocessing
 import threading
 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor
@@ -10,6 +12,7 @@
 from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
 from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
 from surya.settings import settings
+from surya.util.parallel import FakeParallel
 
 
 def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
@@ -192,40 +195,55 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
     layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
     id2label = model.config.id2label
 
-    results = []
+    postprocessing_futures = []
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
+    batch_queue = Queue()
+    processing_error = threading.Event()
+
+    def inference_producer():
+        try:
+            for batch in layout_generator:
+                batch_queue.put(batch)
+                if processing_error.is_set():
+                    break
+        except Exception as e:
+            processing_error.set()
+            print("Error in layout detection producer", e)
+        finally:
+            batch_queue.put(None)  # Signal end of batches
+
+    def postprocessing_consumer(executor):
+        img_idx = 0
+        while not processing_error.is_set():
+            batch = batch_queue.get()
+            if batch is None:
+                break
 
-    if parallelize:
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
-            img_idx = 0
-            for preds, orig_sizes in layout_generator:
-                futures = []
+            try:
+                preds, orig_sizes = batch
                 for pred, orig_size in zip(preds, orig_sizes):
-                    future = executor.submit(
-                        parallel_get_regions,
-                        pred,
-                        orig_size,
-                        id2label,
-                        detection_results[img_idx] if detection_results else None
-                    )
-
-                    futures.append(future)
+                    func = executor.submit if parallelize else FakeParallel
+                    future = func(parallel_get_regions, pred, orig_size, id2label, detection_results[img_idx] if detection_results else None)
+                    postprocessing_futures.append(future)
                     img_idx += 1
-
-                for future in futures:
-                    results.append(future.result())
-    else:
-        img_idx = 0
-        for preds, orig_sizes in layout_generator:
-            for pred, orig_size in zip(preds, orig_sizes):
-                results.append(parallel_get_regions(
-                    pred,
-                    orig_size,
-                    id2label,
-                    detection_results[img_idx] if detection_results else None
-                ))
-
-                img_idx += 1
+            except Exception as e:
+                processing_error.set()
+                print("Error in layout postprocessing", e)
+
+    # Start producer and consumer threads
+    producer = threading.Thread(target=inference_producer, daemon=True)
+    producer.start()
+
+    with ProcessPoolExecutor(
+        max_workers=max_workers,
+        mp_context=multiprocessing.get_context("spawn")
+    ) if parallelize else contextlib.nullcontext() as executor:
+        consumer = threading.Thread(target=postprocessing_consumer, args=(executor,), daemon=True)
+        consumer.start()
+        producer.join()
+        consumer.join()
+
+    results = [future.result() for future in postprocessing_futures]
 
     return results
\ No newline at end of file
diff --git a/surya/model/recognition/tokenizer.py b/surya/model/recognition/tokenizer.py
index d57239a..30018f5 100644
--- a/surya/model/recognition/tokenizer.py
+++ b/surya/model/recognition/tokenizer.py
@@ -26,7 +26,12 @@ def utf16_numbers_to_text(numbers):
         byte_array.append(number & 0xFF)  # Lower byte
         byte_array.append((number >> 8) & 0xFF)  # Upper byte
 
-    text = byte_array.decode('utf-16le', errors="ignore")
+    try:
+        text = byte_array.decode('utf-16le', errors="ignore")
+    except Exception as e:
+        print(f"Error decoding utf16: {e}")
+        text = ""
+
     return text
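[Review note, not part of the patch: decoding with errors="ignore" suppresses UnicodeDecodeError, and the buffer is always an even number of bytes, so the new try/except reads as a defensive guard rather than a fix for a known crash. For reference, a worked example of the patched function; the first two lines of the body are reconstructed from the hunk's context and the rest is copied from the hunk.]

    def utf16_numbers_to_text(numbers):
        byte_array = bytearray()
        for number in numbers:
            byte_array.append(number & 0xFF)  # Lower byte
            byte_array.append((number >> 8) & 0xFF)  # Upper byte

        try:
            text = byte_array.decode('utf-16le', errors="ignore")
        except Exception as e:
            print(f"Error decoding utf16: {e}")
            text = ""

        return text

    print(utf16_numbers_to_text([0x0048, 0x0069]))  # -> Hi
    print(utf16_numbers_to_text([0xD83D, 0xDE00]))  # -> 😀 (surrogate pair recombined)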
diff --git a/surya/util/parallel.py b/surya/util/parallel.py
new file mode 100644
index 0000000..015c425
--- /dev/null
+++ b/surya/util/parallel.py
@@ -0,0 +1,6 @@
+class FakeParallel():
+    def __init__(self, func, *args):
+        self._result = func(*args)
+
+    def result(self):
+        return self._result
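[Usage note, not part of the patch: FakeParallel emulates just enough of concurrent.futures.Future for the call sites above. The function runs eagerly inside __init__ and result() hands back the stored value, so the serial path does its work at submission time while the pool path defers it:]

    from concurrent.futures import ProcessPoolExecutor

    from surya.util.parallel import FakeParallel


    def square(x):
        return x * x


    # Serial path: square(4) runs during construction.
    assert FakeParallel(square, 4).result() == 16

    # Parallel path: identical call shape and identical .result() call.
    if __name__ == "__main__":
        with ProcessPoolExecutor(max_workers=1) as executor:
            assert executor.submit(square, 4).result() == 16

[What the shim does not emulate: exceptions surface at construction rather than from result(), and there is no done(), cancel(), or timeout. The callers in this patch only ever call result(), so that is sufficient.]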