Skip to content

Commit

Permalink
Replace Image + PDF transform with OCR transform (#933)
Browse files Browse the repository at this point in the history
* Replace Image + PDF transform with OCR transform

* Use url column instead of file path

* OCR class changes
  • Loading branch information
DhruvaBansal00 authored Nov 13, 2024
1 parent 940cd8a commit 905845e
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/autolabel/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .serper_api import SerperApi
from .serper_maps import SerperMaps
from .custom_api import CustomApi
from .ocr import OCRTransform
from .webpage_transform import WebpageTransform
from .image import ImageTransform
from typing import Dict
Expand All @@ -15,6 +16,7 @@
logger = logging.getLogger(__name__)

TRANSFORM_REGISTRY = {
TransformType.OCR: OCRTransform,
TransformType.PDF: PDFTransform,
TransformType.WEBPAGE_TRANSFORM: WebpageTransform,
TransformType.IMAGE: ImageTransform,
Expand Down
110 changes: 110 additions & 0 deletions src/autolabel/transforms/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Dict, Any, List

from autolabel.transforms.schema import TransformType
from autolabel.transforms import BaseTransform
from autolabel.cache import BaseCache

from autolabel.transforms.schema import TransformError, TransformErrorType


class OCRTransform(BaseTransform):
"""This class is used to extract text from any document using OCR. The output columns dictionary for this class should include the keys 'content_column' and 'metadata_column'
This transform supports the following image formats: PDF, PNG, JPEG, TIFF, JPEG 2000, GIF, WebP, BMP, and PNM
"""

COLUMN_NAMES = ["content_column"]

def __init__(
self,
cache: BaseCache,
output_columns: Dict[str, Any],
url_column: str,
lang: str = None,
) -> None:
super().__init__(cache, output_columns)
self.url_column = url_column
self.lang = lang

try:
from PIL import Image
import pytesseract
from pdf2image import convert_from_path

self.Image = Image
self.pytesseract = pytesseract
self.convert_from_path = convert_from_path
self.pytesseract.get_tesseract_version()
except ImportError:
raise ImportError(
"pillow, pytesseract, and pdf2image are required to use the ocr transform. Please install pillow, pytesseract, and pdf2image with the following command: pip install pillow pytesseract pdf2image"
)
except EnvironmentError:
raise EnvironmentError(
"The tesseract engine is required to use the ocr transform. Please see https://tesseract-ocr.github.io/tessdoc/Installation.html for installation instructions."
)

@staticmethod
def name() -> str:
return TransformType.OCR

def get_image_ocr(self, image_or_image_path, lang: str = None) -> str:
return self.pytesseract.image_to_string(image_or_image_path, lang=self.lang)

def download_file(self, file_location: str) -> str:
import os
import tempfile
import requests

_, ext = os.path.splitext(file_location)
temp_file = tempfile.NamedTemporaryFile(suffix=ext, delete=False)

# Download file
response = requests.get(file_location)
response.raise_for_status()

# Write to temp file
with open(temp_file.name, "wb") as f:
f.write(response.content)

return temp_file.name

async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""This function transforms an image into text using OCR.
Args:
row (Dict[str, Any]): The row of data to be transformed.
Returns:
Dict[str, Any]: The dict of output columns.
"""
curr_file_location = row[self.url_column]
# download file to temp location if a url
try:
curr_file_path = self.download_file(curr_file_location)
except Exception as e:
raise TransformError(
TransformErrorType.TRANSFORM_ERROR,
f"Error downloading file: {e}",
)
ocr_output = []
if curr_file_path.endswith(".pdf"):
pages = self.convert_from_path(curr_file_path)
ocr_output = [self.get_image_ocr(page, lang=self.lang) for page in pages]
else:
ocr_output = [self.get_image_ocr(curr_file_path, lang=self.lang)]

transformed_row = {
self.output_columns["content_column"]: "\n\n".join(ocr_output),
}
return self._return_output_row(transformed_row)

def params(self) -> Dict[str, Any]:
return {
"output_columns": self.output_columns,
"url_column": self.url_column,
"lang": self.lang,
}

def input_columns(self) -> List[str]:
return [self.url_column]
1 change: 1 addition & 0 deletions src/autolabel/transforms/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class TransformType(str, Enum):
WEB_SEARCH_SERPER = "web_search"
MAPS_SEARCH = "map_search"
CUSTOM_API = "custom_api"
OCR = "ocr"


class TransformCacheEntry(BaseModel):
Expand Down

0 comments on commit 905845e

Please sign in to comment.