diff --git a/paddleocr.py b/paddleocr.py index dc92cbf6b7..36980aec44 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -516,7 +516,19 @@ def img_decode(content: bytes): return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED) -def check_img(img): +def check_img(img, alpha_color=(255, 255, 255)): + """ + Check the image data. If it is another type of image file, try to decode it into a numpy array. + The inference network requires three-channel images, So the following channel conversions are done + single channel image: Gray to RGB R←Y,G←Y,B←Y + four channel image: alpha_to_color + args: + img: image data + file format: jpg, png and other image formats that opencv can decode, as well as gif and pdf formats + storage type: binary image, net image file, local image file + alpha_color: Background color in images in RGBA format + return: numpy.array (h, w, 3) + """ if isinstance(img, bytes): img = img_decode(img) if isinstance(img, str): @@ -550,9 +562,12 @@ def check_img(img): if img is None: logger.error("error in loading image:{}".format(image_file)) return None + # single channel image array.shape:h,w if isinstance(img, np.ndarray) and len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - + # four channel image array.shape:h,w,c + if isinstance(img, np.ndarray) and len(img.shape) == 3 and img.shape[2] == 4: + img = alpha_to_color(img, alpha_color) return img @@ -638,7 +653,7 @@ def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_col 'Since the angle classifier is not initialized, it will not be used during the forward process' ) - img = check_img(img) + img = check_img(img, alpha_color) # for infer pdf file if isinstance(img, list): if self.page_num > len(img) or self.page_num == 0: @@ -648,7 +663,6 @@ def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_col imgs = [img] def preprocess_image(_image): - _image = alpha_to_color(_image, alpha_color) if inv: _image = cv2.bitwise_not(_image) if bin: @@ -755,8 +769,8 @@ def __init__(self, **kwargs): logger.debug(params) super().__init__(params) - def __call__(self, img, return_ocr_result_in_table=False, img_idx=0): - img = check_img(img) + def __call__(self, img, return_ocr_result_in_table=False, img_idx=0, alpha_color=(255, 255, 255)): + img = check_img(img, alpha_color) res, _ = super().__call__( img, return_ocr_result_in_table, img_idx=img_idx) return res