diff --git a/src/autolabel/transforms/ocr.py b/src/autolabel/transforms/ocr.py index e3014302..ce198ae9 100644 --- a/src/autolabel/transforms/ocr.py +++ b/src/autolabel/transforms/ocr.py @@ -19,7 +19,8 @@ class OCRTransform(BaseTransform): - """Extract text from documents using OCR. + """ + Extract text from documents using OCR. The output columns dictionary for this class should include the keys 'content_column' and 'metadata_column'. @@ -37,7 +38,8 @@ def __init__( url_column: str, lang: str | None = None, ) -> None: - """Initialize OCRTransform. + """ + Initialize OCRTransform. Args: cache: Cache instance to use @@ -51,32 +53,23 @@ def __init__( self.url_column = url_column self.lang = lang try: - import pytesseract from pdf2image import convert_from_path - self.pytesseract = pytesseract self.convert_from_path = convert_from_path - self.pytesseract.get_tesseract_version() self.client = boto3.client("textract") except ImportError: msg = ( - "pillow, pytesseract, and pdf2image are required to use ocr" - "Please install with: pip install pillow pytesseract pdf2image" + "pillow, pdf2image are required to use ocr" + "Please install with: pip install pillow pdf2image" ) raise ImportError(msg) from None - except OSError: - msg = ( - "The tesseract engine is required to use the ocr transform. " - "Please see https://tesseract-ocr.github.io/tessdoc/Installation.html " - "for installation instructions." - ) - raise OSError(msg) from None @staticmethod def name() -> str: - """Get transform name. + """ + Get transform name. Returns: Transform type name @@ -89,7 +82,8 @@ def default_ocr_processor( image_or_image_path: Image.Image | str, lang: str | None = None, ) -> str: - """Extract text from image using OCR. + """ + Extract text from image using OCR. Args: image_or_image_path: PIL Image or path to image file @@ -117,7 +111,8 @@ def default_ocr_processor( return "\n".join([block.get("Text", "") for block in blocks]) def download_file(self, file_location: str) -> str: - """Download file from URL to temporary location. + """ + Download file from URL to temporary location. Args: file_location: URL or path of file to download @@ -141,7 +136,8 @@ def download_file(self, file_location: str) -> str: return temp_file.name async def _apply(self, row: dict[str, Any]) -> dict[str, Any]: - """Transform document into text using OCR. + """ + Transform document into text using OCR. Args: row: Row of data to transform @@ -177,7 +173,8 @@ async def _apply(self, row: dict[str, Any]) -> dict[str, Any]: return self._return_output_row(transformed_row) def params(self) -> dict[str, Any]: - """Get transform parameters. + """ + Get transform parameters. Returns: Dict of parameters @@ -190,7 +187,8 @@ def params(self) -> dict[str, Any]: } def input_columns(self) -> list[str]: - """Get required input columns. + """ + Get required input columns. Returns: List of input column names