Skip to content

Commit

Permalink
Improve viewer colormap choices (nerfstudio-project#1348)
Browse files Browse the repository at this point in the history
  • Loading branch information
tancik authored Feb 4, 2023
1 parent f7ee1a2 commit 6331cf2
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 49 deletions.
2 changes: 1 addition & 1 deletion nerfstudio/viewer/app/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "viewer",
"homepage": ".",
"version": "23-01-25-0",
"version": "23-02-3-0",
"private": true,
"dependencies": {
"@emotion/react": "^11.10.4",
Expand Down
51 changes: 51 additions & 0 deletions nerfstudio/viewer/app/src/modules/ConfigPanel/ConfigPanel.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ export function RenderControls() {
const colormapChoice = useSelector(
(state) => state.renderingState.colormap_choice,
);
const colormapInvert = useSelector(
(state) => state.renderingState.colormap_invert,
);
const colormapNormalize = useSelector(
(state) => state.renderingState.colormap_normalize,
);
const max_resolution = useSelector(
(state) => state.renderingState.maxResolution,
);
Expand Down Expand Up @@ -115,6 +121,51 @@ export function RenderControls() {
},
disabled: colormapOptions.length === 1,
},
colormap_invert: {
label: '| Invert',
value: colormapInvert,
hint: 'Invert the colormap',
onChange: (v) => {
dispatch_and_send(
websocket,
dispatch,
'renderingState/colormap_invert',
v,
);
},
render: (get) => get('colormap_options') !== 'default',
},
colormap_normalize: {
label: '| Normalize',
value: colormapNormalize,
hint: 'Whether to normalize output between 0 and 1',
onChange: (v) => {
dispatch_and_send(
websocket,
dispatch,
'renderingState/colormap_normalize',
v,
);
},
render: (get) => get('colormap_options') !== 'default',
},
colormap_range: {
label: '| Range',
value: [0, 1],
step: 0.01,
min: -2,
max: 5,
hint: 'Min and max values of the colormap',
onChange: (v) => {
dispatch_and_send(
websocket,
dispatch,
'renderingState/colormap_range',
v,
);
},
render: (get) => get('colormap_options') !== 'default',
},
// Dynamic Resolution
target_train_util: {
label: 'Train Util.',
Expand Down
9 changes: 7 additions & 2 deletions nerfstudio/viewer/app/src/reducer.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,17 @@ const initialState = {

all_camera_paths: null, // object containing camera paths and names


isTraining: true,

// colormap options
output_options: ['rgb'], // populated by the possible Graph outputs
output_choice: 'rgb', // the selected output
colormap_options: ['default'], // populated by the output choice
colormap_choice: 'default', // the selected colormap
colormap_invert: false, // whether to invert the colormap
colormap_normalize: false, // whether to normalize the colormap
colormap_range: [0.0, 1.0], // the range of the colormap

maxResolution: 1024,
targetTrainUtil: 0.9,
eval_res: '?',
Expand All @@ -51,7 +56,7 @@ const initialState = {

// Crop Box Options
crop_enabled: false,
crop_bg_color: {r: 38, g:42, b:55},
crop_bg_color: { r: 38, g: 42, b: 55 },
crop_scale: [2.0, 2.0, 2.0],
crop_center: [0.0, 0.0, 0.0],
},
Expand Down
103 changes: 57 additions & 46 deletions nerfstudio/viewer/server/viewer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def setup_viewer(config: cfg.ViewerConfig, log_filename: Path, datapath: str):


class OutputTypes(str, enum.Enum):
"""Noncomprehsnive list of output render types"""
"""Noncomprehensive list of output render types"""

INIT = "init"
RGB = "rgb"
Expand All @@ -92,14 +92,14 @@ class OutputTypes(str, enum.Enum):


class ColormapTypes(str, enum.Enum):
"""Noncomprehsnive list of colormap render types"""
"""List of colormap render types"""

INIT = "init"
DEFAULT = "default"
TURBO = "turbo"
DEPTH = "depth"
SEMANTIC = "semantic"
BOOLEAN = "boolean"
VIRIDIS = "viridis"
MAGMA = "magma"
INFERNO = "inferno"
CIVIDIS = "cividis"


class IOChangeException(Exception):
Expand Down Expand Up @@ -220,12 +220,15 @@ def run(self):

# check colormap type
colormap_type = self.state.vis["renderingState/colormap_choice"].read()
if colormap_type is None:
colormap_type = ColormapTypes.INIT
if self.state.prev_colormap_type != colormap_type:
self.state.check_interrupt_vis = True
return

colormap_range = self.state.vis["renderingState/colormap_range"].read()
if self.state.prev_colormap_range != colormap_range:
self.state.check_interrupt_vis = True
return

# check max render
max_resolution = self.state.vis["renderingState/maxResolution"].read()
if max_resolution is not None:
Expand Down Expand Up @@ -285,7 +288,10 @@ def __init__(self, config: cfg.ViewerConfig, log_filename: Path, datapath: str):
self.prev_camera_matrix = None
self.prev_render_time = 0
self.prev_output_type = OutputTypes.INIT
self.prev_colormap_type = ColormapTypes.INIT
self.prev_colormap_type = None
self.prev_colormap_invert = False
self.prev_colormap_normalize = False
self.prev_colormap_range = [0, 1]
self.prev_moving = False
self.output_type_changed = True
self.max_resolution = 1000
Expand Down Expand Up @@ -583,11 +589,21 @@ def _get_camera_object(self):
self.camera_moving = True

colormap_type = self.vis["renderingState/colormap_choice"].read()
if colormap_type is None:
colormap_type = ColormapTypes.INIT
if self.prev_colormap_type != colormap_type:
self.camera_moving = True

colormap_range = self.vis["renderingState/colormap_range"].read()
if self.prev_colormap_range != colormap_range:
self.camera_moving = True

colormap_invert = self.vis["renderingState/colormap_invert"].read()
if self.prev_colormap_invert != colormap_invert:
self.camera_moving = True

colormap_normalize = self.vis["renderingState/colormap_normalize"].read()
if self.prev_colormap_normalize != colormap_normalize:
self.camera_moving = True

crop_bg_color = self.vis["renderingState/crop_bg_color"].read()
if self.prev_crop_enabled:
if self.prev_crop_bg_color != crop_bg_color:
Expand Down Expand Up @@ -621,37 +637,28 @@ def _apply_colormap(self, outputs: Dict[str, Any], colors: torch.Tensor = None,
return outputs[reformatted_output]

# rendering depth outputs
if self.prev_colormap_type == ColormapTypes.DEPTH or (
self.prev_colormap_type == ColormapTypes.DEFAULT
and outputs[reformatted_output].dtype == torch.float
and (torch.max(outputs[reformatted_output]) - 1.0) > eps # handle floating point arithmetic
):
accumulation_str = (
OutputTypes.ACCUMULATION
if OutputTypes.ACCUMULATION in self.output_list
else OutputTypes.ACCUMULATION_FINE
)
return colormaps.apply_depth_colormap(outputs[reformatted_output], accumulation=outputs[accumulation_str])

# rendering accumulation outputs
if self.prev_colormap_type == ColormapTypes.TURBO or (
self.prev_colormap_type == ColormapTypes.DEFAULT and outputs[reformatted_output].dtype == torch.float
):
return colormaps.apply_colormap(outputs[reformatted_output])
if outputs[reformatted_output].shape[-1] == 1 and outputs[reformatted_output].dtype == torch.float:
output = outputs[reformatted_output]
if self.prev_colormap_normalize:
output = output - torch.min(output)
output = output / (torch.max(output) + eps)
output = output * (self.prev_colormap_range[1] - self.prev_colormap_range[0]) + self.prev_colormap_range[0]
output = torch.clip(output, 0, 1)
if self.prev_colormap_invert:
output = 1 - output
if self.prev_colormap_type == ColormapTypes.DEFAULT:
return colormaps.apply_colormap(output, cmap=ColormapTypes.TURBO.value)
return colormaps.apply_colormap(output, cmap=self.prev_colormap_type)

# rendering semantic outputs
if self.prev_colormap_type == ColormapTypes.SEMANTIC or (
self.prev_colormap_type == ColormapTypes.DEFAULT and outputs[reformatted_output].dtype == torch.int
):
if outputs[reformatted_output].dtype == torch.int:
logits = outputs[reformatted_output]
labels = torch.argmax(torch.nn.functional.softmax(logits, dim=-1), dim=-1) # type: ignore
assert colors is not None
return colors[labels]

# rendering boolean outputs
if self.prev_colormap_type == ColormapTypes.BOOLEAN or (
self.prev_colormap_type == ColormapTypes.DEFAULT and outputs[reformatted_output].dtype == torch.bool
):
if outputs[reformatted_output].dtype == torch.bool:
return colormaps.apply_boolean_colormap(outputs[reformatted_output])

raise NotImplementedError
Expand Down Expand Up @@ -713,13 +720,12 @@ def set_image(self, image):
for video_track in self.video_tracks:
video_track.put_frame(image)

def _send_output_to_viewer(self, outputs: Dict[str, Any], colors: torch.Tensor = None, eps=1e-6):
def _send_output_to_viewer(self, outputs: Dict[str, Any], colors: torch.Tensor = None):
"""Chooses the correct output and sends it to the viewer
Args:
outputs: the dictionary of outputs to choose from, from the graph
colors: is only set if colormap is for semantics. Defaults to None.
eps: epsilon to handle floating point comparisons
"""
if self.output_list is None:
self.output_list = list(outputs.keys())
Expand All @@ -732,16 +738,13 @@ def _send_output_to_viewer(self, outputs: Dict[str, Any], colors: torch.Tensor =

reformatted_output = self._process_invalid_output(self.prev_output_type)
# re-register colormaps and send to viewer
if self.output_type_changed or self.prev_colormap_type == ColormapTypes.INIT:
if self.output_type_changed or self.prev_colormap_type is None:
self.prev_colormap_type = ColormapTypes.DEFAULT
colormap_options = [ColormapTypes.DEFAULT]
if (
outputs[reformatted_output].shape[-1] != 3
and outputs[reformatted_output].dtype == torch.float
and (torch.max(outputs[reformatted_output]) - 1.0) <= eps # handle floating point arithmetic
):
# accumulation can also include depth
colormap_options.extend(["depth"])
colormap_options = []
if outputs[reformatted_output].shape[-1] == 3:
colormap_options = [ColormapTypes.DEFAULT]
if outputs[reformatted_output].shape[-1] == 1 and outputs[reformatted_output].dtype == torch.float:
colormap_options = list(ColormapTypes)
self.output_type_changed = False
self.vis["renderingState/colormap_choice"].write(self.prev_colormap_type)
self.vis["renderingState/colormap_options"].write(colormap_options)
Expand Down Expand Up @@ -885,9 +888,17 @@ def _render_image_in_viewer(self, camera_object, graph: Model, is_training: bool

# check and perform colormap type updates
colormap_type = self.vis["renderingState/colormap_choice"].read()
colormap_type = ColormapTypes.INIT if colormap_type is None else colormap_type
self.prev_colormap_type = colormap_type

colormap_invert = self.vis["renderingState/colormap_invert"].read()
self.prev_colormap_invert = colormap_invert

colormap_normalize = self.vis["renderingState/colormap_normalize"].read()
self.prev_colormap_normalize = colormap_normalize

colormap_range = self.vis["renderingState/colormap_range"].read()
self.prev_colormap_range = colormap_range

# update render aabb
try:
self._update_render_aabb(graph)
Expand Down

0 comments on commit 6331cf2

Please sign in to comment.