From 7136f0b8fcab49e57d49784efe78bebefd78caf3 Mon Sep 17 00:00:00 2001 From: CastleDream Date: Mon, 24 Jul 2023 18:34:39 +0800 Subject: [PATCH 1/5] update SegLocVisualize --- mmseg/apis/inference.py | 8 ++- mmseg/utils/__init__.py | 6 +- mmseg/visualization/local_visualizer.py | 73 ++++++++++++++++++++++--- 3 files changed, 75 insertions(+), 12 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 81cd17d798..6a398ebc5e 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -158,6 +158,7 @@ def show_result_pyplot(model: BaseSegmentor, draw_pred: bool = True, wait_time: float = 0, show: bool = True, + withLabels: Optional[bool] = True, save_dir=None, out_file=None): """Visualize the segmentation results on the image. @@ -177,10 +178,14 @@ def show_result_pyplot(model: BaseSegmentor, that means "forever". Defaults to 0. show (bool): Whether to display the drawn image. Default to True. + withLabels(bool, optional): Add semantic labels in visualization + result, Default to True. save_dir (str, optional): Save file dir for all storage backends. If it is None, the backend storage will not save any data. out_file (str, optional): Path to output file. Default to None. + + Returns: np.ndarray: the drawn image which channel is RGB. """ @@ -208,7 +213,8 @@ def show_result_pyplot(model: BaseSegmentor, draw_pred=draw_pred, wait_time=wait_time, out_file=out_file, - show=show) + show=show, + withLabels=withLabels) vis_img = visualizer.get_image() return vis_img diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index cb1436c198..8c5073fbe2 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. # yapf: disable -from .class_names import (ade_classes, ade_palette, cityscapes_classes, +from .class_names import (ade_classes, ade_palette, bdd100k_classes, + bdd100k_palette, cityscapes_classes, cityscapes_palette, cocostuff_classes, cocostuff_palette, dataset_aliases, get_classes, get_palette, isaid_classes, isaid_palette, @@ -27,5 +28,6 @@ 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette', 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette', 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette', - 'datafrombytes', 'synapse_palette', 'synapse_classes' + 'datafrombytes', 'synapse_palette', 'synapse_classes', bdd100k_classes, + bdd100k_palette ] diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 0d693e5820..41b2cbf263 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, List, Optional +import cv2 import mmcv import numpy as np from mmengine.dist import master_only @@ -42,8 +43,8 @@ class SegLocalVisualizer(Visualizer): >>> import numpy as np >>> import torch >>> from mmengine.structures import PixelData - >>> from mmseg.data import SegDataSample - >>> from mmseg.engine.visualization import SegLocalVisualizer + >>> from mmseg.structures import SegDataSample + >>> from mmseg.visualization import SegLocalVisualizer >>> seg_local_visualizer = SegLocalVisualizer() >>> image = np.random.randint(0, 256, @@ -60,7 +61,7 @@ class SegLocalVisualizer(Visualizer): >>> seg_local_visualizer.add_datasample( ... 'visualizer_example', image, ... gt_seg_data_sample, show=True) - """ # noqa + """ # noqa def __init__(self, name: str = 'visualizer', @@ -76,9 +77,33 @@ def __init__(self, self.alpha: float = alpha self.set_dataset_meta(palette, classes, dataset_name) - def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, + def _get_center_loc(self, mask: np.ndarray) -> np.ndarray: + """Get semantic seg center coordinate. + + Args: + mask: np.ndarray: get from sem_seg + """ + mask = mask.astype(np.uint8) + loc = np.argwhere(mask == 1) + + loc_sort = np.array( + sorted(loc.tolist(), key=lambda row: (row[0], row[1]))) + y_list = loc_sort[:, 0] + unique, indices, counts = np.unique( + y_list, return_index=True, return_counts=True) + y_loc = unique[counts.argmax()] + y_most_freq_loc = loc[loc_sort[:, 0] == y_loc] + center_num = len(y_most_freq_loc) // 2 + x = y_most_freq_loc[center_num][1] + y = y_most_freq_loc[center_num][0] + return np.array([x, y]) + + def _draw_sem_seg(self, + image: np.ndarray, + sem_seg: PixelData, classes: Optional[List], - palette: Optional[List]) -> np.ndarray: + palette: Optional[List], + withLabels: Optional[bool] = True) -> np.ndarray: """Draw semantic seg of GT or prediction. Args: @@ -94,6 +119,8 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, palette (list, optional): Input palette for result rendering, which is a list of color palette responding to the classes. Defaults to None. + withLabels(bool, optional): Add semantic labels in visualization + result, Default to True. Returns: np.ndarray: the drawn image which channel is RGB. @@ -112,6 +139,31 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, for label, color in zip(labels, colors): mask[sem_seg[0] == label, :] = color + if withLabels: + font = cv2.FONT_HERSHEY_SIMPLEX + fontScale = 1 + fontColor = (255, 255, 255) + thickness = 3 + lineType = 2 + masks = sem_seg[0] == labels[:, None, None] + for mask_num in range(len(labels)): + classes_id = labels[mask_num] + classes_color = colors[mask_num] + loc = self._get_center_loc(masks[mask_num]) + text = classes[classes_id] + (label_width, label_height), baseline = cv2.getTextSize( + text, font, fontScale, thickness) + mask = cv2.rectangle(mask, loc, + (loc[0] + label_width + baseline, + loc[1] + label_height + baseline), + classes_color, -1) + mask = cv2.rectangle(mask, loc, + (loc[0] + label_width + baseline, + loc[1] + label_height + baseline), + (0, 0, 0), 2) + mask = cv2.putText(mask, text, (loc[0], loc[1] + label_height), + font, fontScale, fontColor, thickness, + lineType) color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype( np.uint8) self.set_image(color_seg) @@ -137,7 +189,7 @@ def set_dataset_meta(self, visulizer will use the meta information of the dataset i.e. classes and palette, but the `classes` and `palette` have higher priority. Defaults to None. - """ # noqa + """ # noqa # Set default value. When calling # `SegLocalVisualizer().dataset_meta=xxx`, # it will override the default value. @@ -161,7 +213,8 @@ def add_datasample( wait_time: float = 0, # TODO: Supported in mmengine's Viusalizer. out_file: Optional[str] = None, - step: int = 0) -> None: + step: int = 0, + withLabels: Optional[bool] = True) -> None: """Draw datasample and save to all backends. - If GT and prediction are plotted at the same time, they are @@ -187,6 +240,8 @@ def add_datasample( wait_time (float): The interval of show (s). Defaults to 0. out_file (str): Path to output file. Defaults to None. step (int): Global step value to record. Defaults to 0. + withLabels(bool, optional): Add semantic labels in visualization + result, Defaults to True. """ classes = self.dataset_meta.get('classes', None) palette = self.dataset_meta.get('palette', None) @@ -202,7 +257,7 @@ def add_datasample( 'segmentation results.' gt_img_data = self._draw_sem_seg(gt_img_data, data_sample.gt_sem_seg, classes, - palette) + palette, withLabels) if (draw_pred and data_sample is not None and 'pred_sem_seg' in data_sample): @@ -213,7 +268,7 @@ def add_datasample( 'segmentation results.' pred_img_data = self._draw_sem_seg(pred_img_data, data_sample.pred_sem_seg, - classes, palette) + classes, palette, withLabels) if gt_img_data is not None and pred_img_data is not None: drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) From 5ba642c9adf0358d99be37c550afaadd050fc92e Mon Sep 17 00:00:00 2001 From: CastleDream Date: Mon, 24 Jul 2023 19:18:52 +0800 Subject: [PATCH 2/5] add quote --- mmseg/utils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index 8c5073fbe2..f69043764a 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -28,6 +28,6 @@ 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette', 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette', 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette', - 'datafrombytes', 'synapse_palette', 'synapse_classes', bdd100k_classes, - bdd100k_palette + 'datafrombytes', 'synapse_palette', 'synapse_classes', 'bdd100k_classes', + 'bdd100k_palette' ] From 47fd9bd3adb61df6bff2e407417eb95c88c93460 Mon Sep 17 00:00:00 2001 From: CastleDream Date: Mon, 24 Jul 2023 19:45:29 +0800 Subject: [PATCH 3/5] update type --- mmseg/visualization/local_visualizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 41b2cbf263..af42b11d0e 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -83,7 +83,6 @@ def _get_center_loc(self, mask: np.ndarray) -> np.ndarray: Args: mask: np.ndarray: get from sem_seg """ - mask = mask.astype(np.uint8) loc = np.argwhere(mask == 1) loc_sort = np.array( @@ -146,6 +145,7 @@ def _draw_sem_seg(self, thickness = 3 lineType = 2 masks = sem_seg[0] == labels[:, None, None] + masks = masks.astype(np.uint8) for mask_num in range(len(labels)): classes_id = labels[mask_num] classes_color = colors[mask_num] From b06768ad1bf85aef273e379fadfde6fded76d645 Mon Sep 17 00:00:00 2001 From: CastleDream Date: Mon, 24 Jul 2023 22:39:37 +0800 Subject: [PATCH 4/5] fix sem_seg torch --- mmseg/visualization/local_visualizer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index af42b11d0e..41d9e8b3af 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -4,6 +4,7 @@ import cv2 import mmcv import numpy as np +import torch from mmengine.dist import master_only from mmengine.structures import PixelData from mmengine.visualization import Visualizer @@ -144,7 +145,11 @@ def _draw_sem_seg(self, fontColor = (255, 255, 255) thickness = 3 lineType = 2 - masks = sem_seg[0] == labels[:, None, None] + + if isinstance(sem_seg[0], torch.Tensor): + masks = sem_seg[0].numpy() == labels[:, None, None] + else: + masks = sem_seg[0] == labels[:, None, None] masks = masks.astype(np.uint8) for mask_num in range(len(labels)): classes_id = labels[mask_num] From 3c20608080b34d9724ecd52266af04c9f19d6295 Mon Sep 17 00:00:00 2001 From: CastleDream Date: Mon, 24 Jul 2023 23:06:42 +0800 Subject: [PATCH 5/5] update self-adaptive to small image --- mmseg/visualization/local_visualizer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 41d9e8b3af..ebd310784e 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -141,9 +141,16 @@ def _draw_sem_seg(self, if withLabels: font = cv2.FONT_HERSHEY_SIMPLEX - fontScale = 1 + # (0,1] to change the size of the text relative to the image + scale = 0.05 + fontScale = min(image.shape[0], image.shape[1]) / (25 / scale) fontColor = (255, 255, 255) - thickness = 3 + if image.shape[0] < 300 or image.shape[1] < 300: + thickness = 1 + rectangleThickness = 1 + else: + thickness = 2 + rectangleThickness = 2 lineType = 2 if isinstance(sem_seg[0], torch.Tensor): @@ -165,7 +172,7 @@ def _draw_sem_seg(self, mask = cv2.rectangle(mask, loc, (loc[0] + label_width + baseline, loc[1] + label_height + baseline), - (0, 0, 0), 2) + (0, 0, 0), rectangleThickness) mask = cv2.putText(mask, text, (loc[0], loc[1] + label_height), font, fontScale, fontColor, thickness, lineType)