Skip to content

Commit

Permalink
Add preprocessing common to OCR tasks
Browse files Browse the repository at this point in the history
Add preprocessing to options
  • Loading branch information
UserUnknownFactor committed Aug 16, 2023
1 parent 175acce commit f893a78
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 9 deletions.
34 changes: 27 additions & 7 deletions paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _import_file(module_name, file_path, make_importable=False):
ppstructure = importlib.import_module('ppstructure', 'paddleocr')
from ppocr.utils.logging import get_logger
from tools.infer import predict_system
from ppocr.utils.utility import check_and_read, get_image_file_list
from ppocr.utils.utility import check_and_read, get_image_file_list, alpha_to_color, binarize_img
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, str2bool, check_gpu
from ppstructure.utility import init_args, draw_structure_result
Expand Down Expand Up @@ -512,7 +512,7 @@ def get_model_config(type, version, model_type, lang):

def img_decode(content: bytes):
np_arr = np.frombuffer(content, dtype=np.uint8)
return cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)


def check_img(img):
Expand Down Expand Up @@ -616,14 +616,17 @@ def __init__(self, **kwargs):
super().__init__(params)
self.page_num = params.page_num

def ocr(self, img, det=True, rec=True, cls=True):
def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_color=(255, 255, 255)):
"""
OCR with PaddleOCR
args:
img: img for OCR, support ndarray, img_path and list or ndarray
det: use text detection or not. If False, only rec will be exec. Default is True
rec: use text recognition or not. If False, only det will be exec. Default is True
cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
bin: binarize image to black and white. Default is False.
inv: invert image colors. Default is False.
alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
"""
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
Expand All @@ -642,9 +645,19 @@ def ocr(self, img, det=True, rec=True, cls=True):
imgs = img[:self.page_num]
else:
imgs = [img]

def preprocess_image(_image):
_image = alpha_to_color(_image, alpha_color)
if inv:
_image = cv2.bitwise_not(_image)
if bin:
_image = binarize_img(_image)
return _image

if det and rec:
ocr_res = []
for idx, img in enumerate(imgs):
img = preprocess_image(img)
dt_boxes, rec_res, _ = self.__call__(img, cls)
if not dt_boxes and not rec_res:
ocr_res.append(None)
Expand All @@ -656,6 +669,7 @@ def ocr(self, img, det=True, rec=True, cls=True):
elif det and not rec:
ocr_res = []
for idx, img in enumerate(imgs):
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if not dt_boxes:
ocr_res.append(None)
Expand All @@ -668,6 +682,7 @@ def ocr(self, img, det=True, rec=True, cls=True):
cls_res = []
for idx, img in enumerate(imgs):
if not isinstance(img, list):
img = preprocess_image(img)
img = [img]
if self.use_angle_cls and cls:
img, cls_res_tmp, elapse = self.text_classifier(img)
Expand Down Expand Up @@ -769,10 +784,15 @@ def main():
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
if args.type == 'ocr':
result = engine.ocr(img_path,
det=args.det,
rec=args.rec,
cls=args.use_angle_cls)
result = engine.ocr(
img_path,
det=args.det,
rec=args.rec,
cls=args.use_angle_cls,
bin=args.binarize,
inv=args.invert,
alpha_color=args.alphacolor
)
if result is not None:
for idx in range(len(result)):
res = result[idx]
Expand Down
19 changes: 19 additions & 0 deletions ppocr/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,25 @@ def get_image_file_list(img_file):
imgs_lists = sorted(imgs_lists)
return imgs_lists

def binarize_img(img):
if len(img.shape) == 3 and img.shape[2] == 3:
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # conversion to grayscale image
# use cv2 threshold binarization
_, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
return img

def alpha_to_color(img, alpha_color=(255, 255, 255)):
if len(img.shape) == 3 and img.shape[2] == 4:
B, G, R, A = cv2.split(img)
alpha = A / 255

R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)

img = cv2.merge((B, G, R))
return img

def check_and_read(img_path):
if os.path.basename(img_path)[-3:] in ['gif', 'GIF']:
Expand Down
17 changes: 16 additions & 1 deletion ppstructure/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import PIL
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args
from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
import math


Expand Down Expand Up @@ -100,6 +100,21 @@ def init_args():
type=str2bool,
default=False,
help='Whether to use pdf2docx api')
parser.add_argument(
"--invert",
type=str2bool,
default=False,
help='Whether to invert image before processing')
parser.add_argument(
"--binarize",
type=str2bool,
default=False,
help='Whether to threshold binarize image before processing')
parser.add_argument(
"--alphacolor",
type=str2int_tuple,
default=(255, 255, 255),
help='Replacement color for the alpha channel, if the latter is present; R,G,B integers')

return parser

Expand Down
4 changes: 3 additions & 1 deletion tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@


def str2bool(v):
return v.lower() in ("true", "t", "1")
return v.lower() in ("true", "yes", "t", "y", "1")

def str2int_tuple(v):
return tuple([int(i.strip()) for i in v.split(",")])

def init_args():
parser = argparse.ArgumentParser()
Expand Down

0 comments on commit f893a78

Please sign in to comment.