From 12e2b7cf79d4ae52f459b082d67e5d42457e5584 Mon Sep 17 00:00:00 2001 From: Sun Jiahao <72679458+sunjiahao1999@users.noreply.github.com> Date: Sun, 25 Jun 2023 14:35:18 +0800 Subject: [PATCH] [Enhance] Support LiDAR visualization (#2611) * need fix multi_modality * update multi modal * remove pcd * fix mix bug * fix nusence_mini * fix msehframe * fix flag exit * add space line --- configs/_base_/datasets/semantickitti.py | 2 +- demo/mono_det_demo.py | 2 +- demo/multi_modality_demo.py | 2 +- demo/pcd_demo.py | 2 +- demo/pcd_seg_demo.py | 2 +- mmdet3d/engine/hooks/visualization_hook.py | 14 ++ mmdet3d/visualization/local_visualizer.py | 190 ++++++++++++++++----- 7 files changed, 165 insertions(+), 49 deletions(-) diff --git a/configs/_base_/datasets/semantickitti.py b/configs/_base_/datasets/semantickitti.py index 61c9ef5b66..ae464d8b60 100644 --- a/configs/_base_/datasets/semantickitti.py +++ b/configs/_base_/datasets/semantickitti.py @@ -113,7 +113,7 @@ dataset_type='semantickitti', backend_args=backend_args), dict(type='PointSegClassMapping'), - dict(type='Pack3DDetInputs', keys=['points']) + dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask']) ] # construct a pipeline for data and gt loading in show function # please keep its loading function consistent with test_pipeline (e.g. client) diff --git a/demo/mono_det_demo.py b/demo/mono_det_demo.py index 19ab7ae474..0096a649e5 100644 --- a/demo/mono_det_demo.py +++ b/demo/mono_det_demo.py @@ -59,7 +59,7 @@ def main(args): data_sample=result, draw_gt=False, show=args.show, - wait_time=0, + wait_time=-1, out_file=args.out_dir, pred_score_thr=args.score_thr, vis_task='mono_det') diff --git a/demo/multi_modality_demo.py b/demo/multi_modality_demo.py index 7d7680197e..9f82867d40 100644 --- a/demo/multi_modality_demo.py +++ b/demo/multi_modality_demo.py @@ -67,7 +67,7 @@ def main(args): data_sample=result, draw_gt=False, show=args.show, - wait_time=0, + wait_time=-1, out_file=args.out_dir, pred_score_thr=args.score_thr, vis_task='multi-modality_det') diff --git a/demo/pcd_demo.py b/demo/pcd_demo.py index 4e5cd511d6..6c9266d4e0 100644 --- a/demo/pcd_demo.py +++ b/demo/pcd_demo.py @@ -49,7 +49,7 @@ def main(args): data_sample=result, draw_gt=False, show=args.show, - wait_time=0, + wait_time=-1, out_file=args.out_dir, pred_score_thr=args.score_thr, vis_task='lidar_det') diff --git a/demo/pcd_seg_demo.py b/demo/pcd_seg_demo.py index 2045fd670c..2f1438604b 100644 --- a/demo/pcd_seg_demo.py +++ b/demo/pcd_seg_demo.py @@ -45,7 +45,7 @@ def main(args): data_sample=result, draw_gt=False, show=args.show, - wait_time=0, + wait_time=-1, out_file=args.out_dir, vis_task='lidar_seg') diff --git a/mmdet3d/engine/hooks/visualization_hook.py b/mmdet3d/engine/hooks/visualization_hook.py index 6c2d18d4c5..c59abe3c5f 100644 --- a/mmdet3d/engine/hooks/visualization_hook.py +++ b/mmdet3d/engine/hooks/visualization_hook.py @@ -7,6 +7,7 @@ import numpy as np from mmengine.fileio import get from mmengine.hooks import Hook +from mmengine.logging import print_log from mmengine.runner import Runner from mmengine.utils import mkdir_or_exist from mmengine.visualization import Visualizer @@ -56,6 +57,8 @@ def __init__(self, vis_task: str = 'mono_det', wait_time: float = 0., test_out_dir: Optional[str] = None, + draw_gt: bool = True, + draw_pred: bool = True, backend_args: Optional[dict] = None): self._visualizer: Visualizer = Visualizer.get_current_instance() self.interval = interval @@ -70,11 +73,20 @@ def __init__(self, 'needs to be excluded.') self.vis_task = vis_task + if wait_time == -1: + print_log( + 'Manual control mode, press [Right] to next sample.', + logger='current') + else: + print_log( + 'Autoplay mode, press [SPACE] to pause.', logger='current') self.wait_time = wait_time self.backend_args = backend_args self.draw = draw self.test_out_dir = test_out_dir self._test_index = 0 + self.draw_gt = draw_gt + self.draw_pred = draw_pred def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, outputs: Sequence[Det3DDataSample]) -> None: @@ -208,6 +220,8 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, 'test sample', data_input, data_sample=data_sample, + draw_gt=self.draw_gt, + draw_pred=self.draw_pred, show=self.show, vis_task=self.vis_task, wait_time=self.wait_time, diff --git a/mmdet3d/visualization/local_visualizer.py b/mmdet3d/visualization/local_visualizer.py index a860f3e40c..98a71b1ac9 100644 --- a/mmdet3d/visualization/local_visualizer.py +++ b/mmdet3d/visualization/local_visualizer.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import math +import sys import time from typing import List, Optional, Sequence, Tuple, Union @@ -12,6 +13,7 @@ from matplotlib.path import Path from mmdet.visualization import DetLocalVisualizer, get_palette from mmengine.dist import master_only +from mmengine.logging import print_log from mmengine.structures import InstanceData from mmengine.visualization import Visualizer as MMENGINE_Visualizer from mmengine.visualization.utils import (check_type, color_val_matplotlib, @@ -136,19 +138,24 @@ def __init__( alpha=alpha) if points is not None: self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg) - self.pts_seg_num = 0 self.multi_imgs_col = multi_imgs_col self.fig_show_cfg.update(fig_show_cfg) + self.flag_pause = False + self.flag_next = False + self.flag_exit = False + def _clear_o3d_vis(self) -> None: """Clear open3d vis.""" if hasattr(self, 'o3d_vis'): del self.o3d_vis - del self.pcd del self.points_colors + del self.view_control + if hasattr(self, 'pcd'): + del self.pcd - def _initialize_o3d_vis(self, frame_cfg: dict) -> Visualizer: + def _initialize_o3d_vis(self) -> Visualizer: """Initialize open3d vis according to frame_cfg. Args: @@ -161,11 +168,16 @@ def _initialize_o3d_vis(self, frame_cfg: dict) -> Visualizer: if o3d is None or geometry is None: raise ImportError( 'Please run "pip install open3d" to install open3d first.') - o3d_vis = o3d.visualization.Visualizer() + glfw_key_escape = 256 # Esc + glfw_key_space = 32 # Space + glfw_key_right = 262 # Right + o3d_vis = o3d.visualization.VisualizerWithKeyCallback() + o3d_vis.register_key_callback(glfw_key_escape, self.escape_callback) + o3d_vis.register_key_action_callback(glfw_key_space, + self.space_action_callback) + o3d_vis.register_key_callback(glfw_key_right, self.right_callback) o3d_vis.create_window() - # create coordinate frame - mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**frame_cfg) - o3d_vis.add_geometry(mesh_frame) + self.view_control = o3d_vis.get_view_control() return o3d_vis @master_only @@ -205,7 +217,7 @@ def set_points(self, check_type('points', points, np.ndarray) if not hasattr(self, 'o3d_vis'): - self.o3d_vis = self._initialize_o3d_vis(frame_cfg) + self.o3d_vis = self._initialize_o3d_vis() # for now we convert points into depth mode for visualization if pcd_mode != Coord3DMode.DEPTH: @@ -235,6 +247,10 @@ def set_points(self, else: raise NotImplementedError + # create coordinate frame + mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**frame_cfg) + self.o3d_vis.add_geometry(mesh_frame) + pcd.colors = o3d.utility.Vector3dVector(points_colors) self.o3d_vis.add_geometry(pcd) self.pcd = pcd @@ -572,12 +588,15 @@ def draw_seg_mask(self, seg_mask_colors: np.ndarray) -> None: # we can't draw the colors on existing points # in case gt and pred mask would overlap # instead we set a large offset along x-axis for each seg mask - self.pts_seg_num += 1 - offset = (np.array(self.pcd.points).max(0) - - np.array(self.pcd.points).min(0))[0] * 1.2 * self.pts_seg_num - mesh_frame = geometry.TriangleMesh.create_coordinate_frame( - size=1, origin=[offset, 0, 0]) # create coordinate frame for seg - self.o3d_vis.add_geometry(mesh_frame) + if hasattr(self, 'pcd'): + offset = (np.array(self.pcd.points).max(0) - + np.array(self.pcd.points).min(0))[0] * 1.2 + mesh_frame = geometry.TriangleMesh.create_coordinate_frame( + size=1, origin=[offset, 0, + 0]) # create coordinate frame for seg + self.o3d_vis.add_geometry(mesh_frame) + else: + offset = 0 seg_points = copy.deepcopy(seg_mask_colors) seg_points[:, 0] += offset self.set_points(seg_points, pcd_mode=2, vis_mode='add', mode='xyzrgb') @@ -716,7 +735,7 @@ def _draw_pts_sem_seg(self, points: Union[Tensor, np.ndarray], pts_seg: PointData, palette: Optional[List[tuple]] = None, - ignore_index: Optional[int] = None) -> None: + keep_index: Optional[int] = None) -> None: """Draw 3D semantic mask of GT or prediction. Args: @@ -733,14 +752,14 @@ def _draw_pts_sem_seg(self, pts_sem_seg = tensor2ndarray(pts_seg.pts_semantic_mask) palette = np.array(palette) - if ignore_index is not None: - points = points[pts_sem_seg != ignore_index] - pts_sem_seg = pts_sem_seg[pts_sem_seg != ignore_index] + if keep_index is not None: + keep_index = tensor2ndarray(keep_index) + points = points[keep_index] + pts_sem_seg = pts_sem_seg[keep_index] pts_color = palette[pts_sem_seg] seg_color = np.concatenate([points[:, :3], pts_color], axis=1) - self.set_points(points, pcd_mode=2, vis_mode='add') self.draw_seg_mask(seg_color) @master_only @@ -749,8 +768,8 @@ def show(self, drawn_img_3d: Optional[np.ndarray] = None, drawn_img: Optional[np.ndarray] = None, win_name: str = 'image', - wait_time: int = 0, - continue_key: str = ' ', + wait_time: int = -1, + continue_key: str = 'right', vis_task: str = 'lidar_det') -> None: """Show the drawn point cloud/image. @@ -768,10 +787,6 @@ def show(self, means "forever". Defaults to 0. continue_key (str): The key for users to continue. Defaults to ' '. """ - if vis_task == 'multi-modality_det': - img_wait_time = 0.5 - else: - img_wait_time = wait_time # In order to show multi-modal results at the same time, we show image # firstly and then show point cloud since the running of @@ -779,34 +794,119 @@ def show(self, if hasattr(self, '_image'): if drawn_img is None and drawn_img_3d is None: # use the image got by Visualizer.get_image() - super().show(drawn_img_3d, win_name, img_wait_time, - continue_key) - else: - if drawn_img_3d is not None: - super().show(drawn_img_3d, win_name, img_wait_time, - continue_key) - if drawn_img is not None: - super().show(drawn_img, win_name, img_wait_time, + if vis_task == 'multi-modality_det': + import matplotlib.pyplot as plt + is_inline = 'inline' in plt.get_backend() + img = self.get_image() if drawn_img is None else drawn_img + self._init_manager(win_name) + fig = self.manager.canvas.figure + # remove white edges by set subplot margin + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + fig.clear() + ax = fig.add_subplot() + ax.axis(False) + ax.imshow(img) + self.manager.canvas.draw() + if is_inline: + return fig + else: + fig.show() + self.manager.canvas.flush_events() + else: + super().show(drawn_img_3d, win_name, wait_time, continue_key) + else: + if vis_task == 'multi-modality_det': + import matplotlib.pyplot as plt + is_inline = 'inline' in plt.get_backend() + img = drawn_img if drawn_img_3d is None else drawn_img_3d + self._init_manager(win_name) + fig = self.manager.canvas.figure + # remove white edges by set subplot margin + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + fig.clear() + ax = fig.add_subplot() + ax.axis(False) + ax.imshow(img) + self.manager.canvas.draw() + if is_inline: + return fig + else: + fig.show() + self.manager.canvas.flush_events() + else: + if drawn_img_3d is not None: + super().show(drawn_img_3d, win_name, wait_time, + continue_key) + if drawn_img is not None: + super().show(drawn_img, win_name, wait_time, + continue_key) if hasattr(self, 'o3d_vis'): - self.o3d_vis.poll_events() + if hasattr(self, 'view_port'): + self.view_control.convert_from_pinhole_camera_parameters( + self.view_port) + self.flag_exit = not self.o3d_vis.poll_events() self.o3d_vis.update_renderer() - if wait_time > 0: - time.sleep(wait_time) + self.view_port = \ + self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501 + if wait_time != -1: + self.last_time = time.time() + while time.time( + ) - self.last_time < wait_time and self.o3d_vis.poll_events(): + self.o3d_vis.update_renderer() + self.view_port = \ + self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501 + while self.flag_pause and self.o3d_vis.poll_events(): + self.o3d_vis.update_renderer() + self.view_port = \ + self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501 + else: - self.o3d_vis.run() + while not self.flag_next and self.o3d_vis.poll_events(): + self.o3d_vis.update_renderer() + self.view_port = \ + self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501 + self.flag_next = False + self.o3d_vis.clear_geometries() + try: + del self.pcd + except KeyError: + pass if save_path is not None: if not (save_path.endswith('.png') or save_path.endswith('.jpg')): save_path += '.png' self.o3d_vis.capture_screen_image(save_path) + if self.flag_exit: + self.o3d_vis.destroy_window() + self.o3d_vis.close() + self._clear_o3d_vis() + sys.exit(0) + + def escape_callback(self, vis): + self.o3d_vis.clear_geometries() + self.o3d_vis.destroy_window() + self.o3d_vis.close() + self._clear_o3d_vis() + sys.exit(0) + + def space_action_callback(self, vis, action, mods): + if action == 1: + if self.flag_pause: + print_log( + 'Playback continued, press [SPACE] to pause.', + logger='current') + else: + print_log( + 'Playback paused, press [SPACE] to continue.', + logger='current') + self.flag_pause = not self.flag_pause + return True - # TODO: support more flexible window control - self.o3d_vis.clear_geometries() - self.o3d_vis.destroy_window() - self.o3d_vis.close() - self._clear_o3d_vis() + def right_callback(self, vis): + self.flag_next = True + return False # TODO: Support Visualize the 3D results from image and point cloud # respectively @@ -862,6 +962,8 @@ def add_datasample(self, # For object detection datasets, no palette is saved palette = self.dataset_meta.get('palette', None) ignore_index = self.dataset_meta.get('ignore_index', None) + if ignore_index is not None and 'gt_pts_seg' in data_sample and vis_task == 'lidar_seg': # noqa: E501 + keep_index = data_sample.gt_pts_seg.pts_semantic_mask != ignore_index # noqa: E501 gt_data_3d = None pred_data_3d = None @@ -890,7 +992,7 @@ def add_datasample(self, assert 'points' in data_input self._draw_pts_sem_seg(data_input['points'], data_sample.gt_pts_seg, palette, - ignore_index) + keep_index) if draw_pred and data_sample is not None: if 'pred_instances_3d' in data_sample: @@ -922,7 +1024,7 @@ def add_datasample(self, assert 'points' in data_input self._draw_pts_sem_seg(data_input['points'], data_sample.pred_pts_seg, palette, - ignore_index) + keep_index) # monocular 3d object detection image if vis_task in ['mono_det', 'multi-modality_det']: