Skip to content

Commit

Permalink
Remove pytesseract references
Browse files Browse the repository at this point in the history
  • Loading branch information
DhruvaBansal00 committed Nov 14, 2024
1 parent 7a93153 commit 9c4c11a
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions src/autolabel/transforms/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@


class OCRTransform(BaseTransform):
"""Extract text from documents using OCR.
"""
Extract text from documents using OCR.
The output columns dictionary for this class should include the keys 'content_column'
and 'metadata_column'.
Expand All @@ -37,7 +38,8 @@ def __init__(
url_column: str,
lang: str | None = None,
) -> None:
"""Initialize OCRTransform.
"""
Initialize OCRTransform.
Args:
cache: Cache instance to use
Expand All @@ -51,32 +53,23 @@ def __init__(
self.url_column = url_column
self.lang = lang
try:
import pytesseract
from pdf2image import convert_from_path

self.pytesseract = pytesseract
self.convert_from_path = convert_from_path
self.pytesseract.get_tesseract_version()

self.client = boto3.client("textract")

except ImportError:
msg = (
"pillow, pytesseract, and pdf2image are required to use ocr"
"Please install with: pip install pillow pytesseract pdf2image"
"pillow, pdf2image are required to use ocr"
"Please install with: pip install pillow pdf2image"
)
raise ImportError(msg) from None
except OSError:
msg = (
"The tesseract engine is required to use the ocr transform. "
"Please see https://tesseract-ocr.github.io/tessdoc/Installation.html "
"for installation instructions."
)
raise OSError(msg) from None

@staticmethod
def name() -> str:
"""Get transform name.
"""
Get transform name.
Returns:
Transform type name
Expand All @@ -89,7 +82,8 @@ def default_ocr_processor(
image_or_image_path: Image.Image | str,
lang: str | None = None,
) -> str:
"""Extract text from image using OCR.
"""
Extract text from image using OCR.
Args:
image_or_image_path: PIL Image or path to image file
Expand Down Expand Up @@ -117,7 +111,8 @@ def default_ocr_processor(
return "\n".join([block.get("Text", "") for block in blocks])

def download_file(self, file_location: str) -> str:
"""Download file from URL to temporary location.
"""
Download file from URL to temporary location.
Args:
file_location: URL or path of file to download
Expand All @@ -141,7 +136,8 @@ def download_file(self, file_location: str) -> str:
return temp_file.name

async def _apply(self, row: dict[str, Any]) -> dict[str, Any]:
"""Transform document into text using OCR.
"""
Transform document into text using OCR.
Args:
row: Row of data to transform
Expand Down Expand Up @@ -177,7 +173,8 @@ async def _apply(self, row: dict[str, Any]) -> dict[str, Any]:
return self._return_output_row(transformed_row)

def params(self) -> dict[str, Any]:
"""Get transform parameters.
"""
Get transform parameters.
Returns:
Dict of parameters
Expand All @@ -190,7 +187,8 @@ def params(self) -> dict[str, Any]:
}

def input_columns(self) -> list[str]:
"""Get required input columns.
"""
Get required input columns.
Returns:
List of input column names
Expand Down

0 comments on commit 9c4c11a

Please sign in to comment.