From 6331cf2234fb7d5262a6dcbae46fbb4c732cc669 Mon Sep 17 00:00:00 2001 From: Matthew Tancik Date: Fri, 3 Feb 2023 23:15:05 -0800 Subject: [PATCH] Improve viewer colormap choices (#1348) --- nerfstudio/viewer/app/package.json | 2 +- .../src/modules/ConfigPanel/ConfigPanel.jsx | 51 +++++++++ nerfstudio/viewer/app/src/reducer.js | 9 +- nerfstudio/viewer/server/viewer_utils.py | 103 ++++++++++-------- 4 files changed, 116 insertions(+), 49 deletions(-) diff --git a/nerfstudio/viewer/app/package.json b/nerfstudio/viewer/app/package.json index 37f2b7bbbd..171ec8dd12 100644 --- a/nerfstudio/viewer/app/package.json +++ b/nerfstudio/viewer/app/package.json @@ -1,7 +1,7 @@ { "name": "viewer", "homepage": ".", - "version": "23-01-25-0", + "version": "23-02-3-0", "private": true, "dependencies": { "@emotion/react": "^11.10.4", diff --git a/nerfstudio/viewer/app/src/modules/ConfigPanel/ConfigPanel.jsx b/nerfstudio/viewer/app/src/modules/ConfigPanel/ConfigPanel.jsx index a1c331fad8..66b92458e5 100644 --- a/nerfstudio/viewer/app/src/modules/ConfigPanel/ConfigPanel.jsx +++ b/nerfstudio/viewer/app/src/modules/ConfigPanel/ConfigPanel.jsx @@ -37,6 +37,12 @@ export function RenderControls() { const colormapChoice = useSelector( (state) => state.renderingState.colormap_choice, ); + const colormapInvert = useSelector( + (state) => state.renderingState.colormap_invert, + ); + const colormapNormalize = useSelector( + (state) => state.renderingState.colormap_normalize, + ); const max_resolution = useSelector( (state) => state.renderingState.maxResolution, ); @@ -115,6 +121,51 @@ export function RenderControls() { }, disabled: colormapOptions.length === 1, }, + colormap_invert: { + label: '| Invert', + value: colormapInvert, + hint: 'Invert the colormap', + onChange: (v) => { + dispatch_and_send( + websocket, + dispatch, + 'renderingState/colormap_invert', + v, + ); + }, + render: (get) => get('colormap_options') !== 'default', + }, + colormap_normalize: { + label: '| Normalize', + value: colormapNormalize, + hint: 'Whether to normalize output between 0 and 1', + onChange: (v) => { + dispatch_and_send( + websocket, + dispatch, + 'renderingState/colormap_normalize', + v, + ); + }, + render: (get) => get('colormap_options') !== 'default', + }, + colormap_range: { + label: '| Range', + value: [0, 1], + step: 0.01, + min: -2, + max: 5, + hint: 'Min and max values of the colormap', + onChange: (v) => { + dispatch_and_send( + websocket, + dispatch, + 'renderingState/colormap_range', + v, + ); + }, + render: (get) => get('colormap_options') !== 'default', + }, // Dynamic Resolution target_train_util: { label: 'Train Util.', diff --git a/nerfstudio/viewer/app/src/reducer.js b/nerfstudio/viewer/app/src/reducer.js index 6de1ea7c25..d8335c80d5 100644 --- a/nerfstudio/viewer/app/src/reducer.js +++ b/nerfstudio/viewer/app/src/reducer.js @@ -30,12 +30,17 @@ const initialState = { all_camera_paths: null, // object containing camera paths and names - isTraining: true, + + // colormap options output_options: ['rgb'], // populated by the possible Graph outputs output_choice: 'rgb', // the selected output colormap_options: ['default'], // populated by the output choice colormap_choice: 'default', // the selected colormap + colormap_invert: false, // whether to invert the colormap + colormap_normalize: false, // whether to normalize the colormap + colormap_range: [0.0, 1.0], // the range of the colormap + maxResolution: 1024, targetTrainUtil: 0.9, eval_res: '?', @@ -51,7 +56,7 @@ const initialState = { // Crop Box Options crop_enabled: false, - crop_bg_color: {r: 38, g:42, b:55}, + crop_bg_color: { r: 38, g: 42, b: 55 }, crop_scale: [2.0, 2.0, 2.0], crop_center: [0.0, 0.0, 0.0], }, diff --git a/nerfstudio/viewer/server/viewer_utils.py b/nerfstudio/viewer/server/viewer_utils.py index b59381f26d..ce0858bb40 100644 --- a/nerfstudio/viewer/server/viewer_utils.py +++ b/nerfstudio/viewer/server/viewer_utils.py @@ -82,7 +82,7 @@ def setup_viewer(config: cfg.ViewerConfig, log_filename: Path, datapath: str): class OutputTypes(str, enum.Enum): - """Noncomprehsnive list of output render types""" + """Noncomprehensive list of output render types""" INIT = "init" RGB = "rgb" @@ -92,14 +92,14 @@ class OutputTypes(str, enum.Enum): class ColormapTypes(str, enum.Enum): - """Noncomprehsnive list of colormap render types""" + """List of colormap render types""" - INIT = "init" DEFAULT = "default" TURBO = "turbo" - DEPTH = "depth" - SEMANTIC = "semantic" - BOOLEAN = "boolean" + VIRIDIS = "viridis" + MAGMA = "magma" + INFERNO = "inferno" + CIVIDIS = "cividis" class IOChangeException(Exception): @@ -220,12 +220,15 @@ def run(self): # check colormap type colormap_type = self.state.vis["renderingState/colormap_choice"].read() - if colormap_type is None: - colormap_type = ColormapTypes.INIT if self.state.prev_colormap_type != colormap_type: self.state.check_interrupt_vis = True return + colormap_range = self.state.vis["renderingState/colormap_range"].read() + if self.state.prev_colormap_range != colormap_range: + self.state.check_interrupt_vis = True + return + # check max render max_resolution = self.state.vis["renderingState/maxResolution"].read() if max_resolution is not None: @@ -285,7 +288,10 @@ def __init__(self, config: cfg.ViewerConfig, log_filename: Path, datapath: str): self.prev_camera_matrix = None self.prev_render_time = 0 self.prev_output_type = OutputTypes.INIT - self.prev_colormap_type = ColormapTypes.INIT + self.prev_colormap_type = None + self.prev_colormap_invert = False + self.prev_colormap_normalize = False + self.prev_colormap_range = [0, 1] self.prev_moving = False self.output_type_changed = True self.max_resolution = 1000 @@ -583,11 +589,21 @@ def _get_camera_object(self): self.camera_moving = True colormap_type = self.vis["renderingState/colormap_choice"].read() - if colormap_type is None: - colormap_type = ColormapTypes.INIT if self.prev_colormap_type != colormap_type: self.camera_moving = True + colormap_range = self.vis["renderingState/colormap_range"].read() + if self.prev_colormap_range != colormap_range: + self.camera_moving = True + + colormap_invert = self.vis["renderingState/colormap_invert"].read() + if self.prev_colormap_invert != colormap_invert: + self.camera_moving = True + + colormap_normalize = self.vis["renderingState/colormap_normalize"].read() + if self.prev_colormap_normalize != colormap_normalize: + self.camera_moving = True + crop_bg_color = self.vis["renderingState/crop_bg_color"].read() if self.prev_crop_enabled: if self.prev_crop_bg_color != crop_bg_color: @@ -621,37 +637,28 @@ def _apply_colormap(self, outputs: Dict[str, Any], colors: torch.Tensor = None, return outputs[reformatted_output] # rendering depth outputs - if self.prev_colormap_type == ColormapTypes.DEPTH or ( - self.prev_colormap_type == ColormapTypes.DEFAULT - and outputs[reformatted_output].dtype == torch.float - and (torch.max(outputs[reformatted_output]) - 1.0) > eps # handle floating point arithmetic - ): - accumulation_str = ( - OutputTypes.ACCUMULATION - if OutputTypes.ACCUMULATION in self.output_list - else OutputTypes.ACCUMULATION_FINE - ) - return colormaps.apply_depth_colormap(outputs[reformatted_output], accumulation=outputs[accumulation_str]) - - # rendering accumulation outputs - if self.prev_colormap_type == ColormapTypes.TURBO or ( - self.prev_colormap_type == ColormapTypes.DEFAULT and outputs[reformatted_output].dtype == torch.float - ): - return colormaps.apply_colormap(outputs[reformatted_output]) + if outputs[reformatted_output].shape[-1] == 1 and outputs[reformatted_output].dtype == torch.float: + output = outputs[reformatted_output] + if self.prev_colormap_normalize: + output = output - torch.min(output) + output = output / (torch.max(output) + eps) + output = output * (self.prev_colormap_range[1] - self.prev_colormap_range[0]) + self.prev_colormap_range[0] + output = torch.clip(output, 0, 1) + if self.prev_colormap_invert: + output = 1 - output + if self.prev_colormap_type == ColormapTypes.DEFAULT: + return colormaps.apply_colormap(output, cmap=ColormapTypes.TURBO.value) + return colormaps.apply_colormap(output, cmap=self.prev_colormap_type) # rendering semantic outputs - if self.prev_colormap_type == ColormapTypes.SEMANTIC or ( - self.prev_colormap_type == ColormapTypes.DEFAULT and outputs[reformatted_output].dtype == torch.int - ): + if outputs[reformatted_output].dtype == torch.int: logits = outputs[reformatted_output] labels = torch.argmax(torch.nn.functional.softmax(logits, dim=-1), dim=-1) # type: ignore assert colors is not None return colors[labels] # rendering boolean outputs - if self.prev_colormap_type == ColormapTypes.BOOLEAN or ( - self.prev_colormap_type == ColormapTypes.DEFAULT and outputs[reformatted_output].dtype == torch.bool - ): + if outputs[reformatted_output].dtype == torch.bool: return colormaps.apply_boolean_colormap(outputs[reformatted_output]) raise NotImplementedError @@ -713,13 +720,12 @@ def set_image(self, image): for video_track in self.video_tracks: video_track.put_frame(image) - def _send_output_to_viewer(self, outputs: Dict[str, Any], colors: torch.Tensor = None, eps=1e-6): + def _send_output_to_viewer(self, outputs: Dict[str, Any], colors: torch.Tensor = None): """Chooses the correct output and sends it to the viewer Args: outputs: the dictionary of outputs to choose from, from the graph colors: is only set if colormap is for semantics. Defaults to None. - eps: epsilon to handle floating point comparisons """ if self.output_list is None: self.output_list = list(outputs.keys()) @@ -732,16 +738,13 @@ def _send_output_to_viewer(self, outputs: Dict[str, Any], colors: torch.Tensor = reformatted_output = self._process_invalid_output(self.prev_output_type) # re-register colormaps and send to viewer - if self.output_type_changed or self.prev_colormap_type == ColormapTypes.INIT: + if self.output_type_changed or self.prev_colormap_type is None: self.prev_colormap_type = ColormapTypes.DEFAULT - colormap_options = [ColormapTypes.DEFAULT] - if ( - outputs[reformatted_output].shape[-1] != 3 - and outputs[reformatted_output].dtype == torch.float - and (torch.max(outputs[reformatted_output]) - 1.0) <= eps # handle floating point arithmetic - ): - # accumulation can also include depth - colormap_options.extend(["depth"]) + colormap_options = [] + if outputs[reformatted_output].shape[-1] == 3: + colormap_options = [ColormapTypes.DEFAULT] + if outputs[reformatted_output].shape[-1] == 1 and outputs[reformatted_output].dtype == torch.float: + colormap_options = list(ColormapTypes) self.output_type_changed = False self.vis["renderingState/colormap_choice"].write(self.prev_colormap_type) self.vis["renderingState/colormap_options"].write(colormap_options) @@ -885,9 +888,17 @@ def _render_image_in_viewer(self, camera_object, graph: Model, is_training: bool # check and perform colormap type updates colormap_type = self.vis["renderingState/colormap_choice"].read() - colormap_type = ColormapTypes.INIT if colormap_type is None else colormap_type self.prev_colormap_type = colormap_type + colormap_invert = self.vis["renderingState/colormap_invert"].read() + self.prev_colormap_invert = colormap_invert + + colormap_normalize = self.vis["renderingState/colormap_normalize"].read() + self.prev_colormap_normalize = colormap_normalize + + colormap_range = self.vis["renderingState/colormap_range"].read() + self.prev_colormap_range = colormap_range + # update render aabb try: self._update_render_aabb(graph)