-
Notifications
You must be signed in to change notification settings - Fork 0
/
ex_draw_seg_mask.py
29 lines (16 loc) · 941 Bytes
/
ex_draw_seg_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from pathlib import Path

import cv2

from tensorrt_yolov8 import EngineHelper
from tensorrt_yolov8.models.utils import (
    draw_segmentation_results,
    get_printable_masks,
    get_scaled_segmentation_masks,
)
def mask_output_path(image_path, index, label):
    """Return the file path for the ``index``-th mask of ``image_path``.

    The mask is placed next to the input image and named
    ``<image stem>_mask<index>_<label>.jpg``. The stem is taken with
    ``pathlib``, so directories containing ".jpg", upper-case extensions and
    non-jpg inputs are all handled correctly.
    """
    source = Path(image_path)
    # A path separator inside the label would turn the file name into a
    # sub-path (and make Path.with_name raise), so neutralise it.
    safe_label = str(label).replace("/", "_").replace("\\", "_")
    # NOTE(review): masks are kept as .jpg to match the original output names;
    # JPEG is lossy, so PNG may be preferable for crisp mask edges.
    return str(source.with_name(f"{source.stem}_mask{index}_{safe_label}.jpg"))


def main(
    model_path="yolov8s_seg_b1_fp32.engine",
    image_path="demo_img.jpg",
    *,
    min_prob=0.5,
    top_k=3,
):
    """Run the segmentation engine on one image and save each model mask.

    Args:
        model_path: Path of the TensorRT engine file passed to ``EngineHelper``.
        image_path: Image to segment; masks are written next to it.
        min_prob: Forwarded to the engine call (presumably a confidence
            threshold; verify against ``EngineHelper``).
        top_k: Forwarded to the engine call (presumably the maximum number of
            detections kept; verify against ``EngineHelper``).

    Returns:
        The list of mask file paths that were written.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
        OSError: If OpenCV fails to write a mask file.
    """
    # Read the image first so a bad path fails before the engine is loaded.
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    segmentation = EngineHelper("segmentation", model_path)
    results = segmentation(image, min_prob=min_prob, top_k=top_k)
    image_hw = image.shape[:2]

    # Per-object masks with respect to the original image. Kept to illustrate
    # the API; the value is not used further in this example.
    _scaled_masks = get_scaled_segmentation_masks(results, image_hw, wrt_original=True)

    # Masks as the model outputs them, in a form suitable for writing to disk.
    model_masks = get_printable_masks(results, image_hw)

    written = []
    for index, (result, mask) in enumerate(zip(results, model_masks)):
        out_path = mask_output_path(image_path, index, result.class_label)
        # cv2.imwrite reports failure via its return value rather than raising.
        if not cv2.imwrite(out_path, mask):
            raise OSError(f"Could not write mask to: {out_path}")
        written.append(out_path)
    return written


if __name__ == "__main__":
    main()