diff --git a/pyroengine/utils.py b/pyroengine/utils.py index 72b0a84f..17dd0575 100644 --- a/pyroengine/utils.py +++ b/pyroengine/utils.py @@ -61,7 +61,7 @@ def letterbox( im_b = np.zeros((h + top + bottom, w + left + right, 3)) + color im_b[top : top + h, left : left + w, :] = im - return im_b.astype("uint8") + return im_b.astype("uint8"), (left, top) def box_iou(box1: np.ndarray, box2: np.ndarray, eps: float = 1e-7): diff --git a/pyroengine/vision.py b/pyroengine/vision.py index 53ab89e1..a9ea22dd 100644 --- a/pyroengine/vision.py +++ b/pyroengine/vision.py @@ -4,7 +4,7 @@ # See LICENSE or go to for full license details. import os -from typing import Optional +from typing import Optional, Tuple from urllib.request import urlretrieve import numpy as np @@ -41,26 +41,27 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tupl self.ort_session = onnxruntime.InferenceSession(model_path) self.img_size = img_size - def preprocess_image(self, pil_img: Image.Image, mask: Optional[np.ndarray] = None) -> np.ndarray: + def preprocess_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Tuple[int, int]]: """Preprocess an image for inference Args: - pil_img: a valid pillow image - mask: occlusion mask to drop prediction in an area + pil_img: A valid PIL image. Returns: - the resized and normalized image of shape (1, C, H, W) + A tuple containing: + - The resized and normalized image of shape (1, C, H, W). + - Padding offsets as a tuple of integers (left_pad, top_pad), as returned by letterbox. 
""" - np_img = letterbox(np.array(pil_img), self.img_size) # letterbox - np_img = np.expand_dims(np_img.astype("float"), axis=0) - np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # BHWC to BCHW - np_img = np_img.astype("float32") / 255 + np_img, pad = letterbox(np.array(pil_img), self.img_size) # Applies letterbox resize with padding + np_img = np.expand_dims(np_img.astype("float"), axis=0) # Add batch dimension + np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # Convert from BHWC to BCHW format + np_img = np_img.astype("float32") / 255 # Normalize to [0, 1] - return np_img + return np_img, pad def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] = None) -> np.ndarray: - np_img = self.preprocess_image(pil_img) + np_img, pad = self.preprocess_image(pil_img) # ONNX inference y = self.ort_session.run(["output0"], {"images": np_img})[0][0] @@ -72,10 +73,15 @@ def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] = # Sort by confidence y = y[y[:, 4].argsort()] y = nms(y) + # Normalize preds if len(y) > 0: - y[:, :4:2] /= self.img_size[1] - y[:, 1:4:2] /= self.img_size[0] + # Remove padding + left_pad, top_pad = pad + y[:, :4:2] -= left_pad + y[:, 1:4:2] -= top_pad + y[:, :4:2] /= self.img_size[1] - 2 * left_pad + y[:, 1:4:2] /= self.img_size[0] - 2 * top_pad else: y = np.zeros((0, 5)) # normalize output diff --git a/tests/test_vision.py b/tests/test_vision.py index 5c895de9..dada6dbe 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -7,9 +7,10 @@ def test_classifier(mock_wildfire_image): # Instantiate the ONNX model model = Classifier() # Check preprocessing - out = model.preprocess_image(mock_wildfire_image) + out, pad = model.preprocess_image(mock_wildfire_image) assert isinstance(out, np.ndarray) and out.dtype == np.float32 assert out.shape == (1, 3, 384, 640) + assert isinstance(pad, tuple) # Check inference out = model(mock_wildfire_image) assert out.shape == (1, 5)