Skip to content

Commit

Permalink
Improve viewer banner when server is bound to 0.0.0.0 (#2761)
Browse files Browse the repository at this point in the history
* Improve viewer banner when server is bound to 0.0.0.0

* Add comment

* Update viewer.py
  • Loading branch information
brentyi authored Jan 13, 2024
1 parent 44dafa3 commit 69ed6d4
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 19 deletions.
9 changes: 5 additions & 4 deletions nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
import functools
import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from typing import Dict, List, Literal, Optional, Tuple, Type, cast, DefaultDict
from collections import defaultdict
from typing import DefaultDict, Dict, List, Literal, Optional, Tuple, Type, cast

import torch
from nerfstudio.configs.experiment_config import ExperimentConfig
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
Expand All @@ -36,8 +37,8 @@
from nerfstudio.utils.misc import step_check
from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.utils.writer import EventName, TimeWriter
from nerfstudio.viewer_legacy.server.viewer_state import ViewerLegacyState
from nerfstudio.viewer.viewer import Viewer as ViewerState
from nerfstudio.viewer_legacy.server.viewer_state import ViewerLegacyState
from rich import box, style
from rich.panel import Panel
from rich.table import Table
Expand Down Expand Up @@ -182,7 +183,7 @@ def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
train_lock=self.train_lock,
share=self.config.viewer.make_share_url,
)
banner_messages = [f"Viewer at: {self.viewer_state.viewer_url}"]
banner_messages = self.viewer_state.viewer_info
self._check_viewer_warnings()

self._load_checkpoint()
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/scripts/viewer/run_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _start_viewer(config: TrainerConfig, pipeline: Pipeline, step: int):
datapath=pipeline.datamanager.get_datapath(),
pipeline=pipeline,
)
banner_messages = [f"Legaccy viewer at: {viewer_state.viewer_url}"]
banner_messages = [f"Legacy viewer at: {viewer_state.viewer_url}"]
if config.vis == "viewer":
viewer_state = ViewerState(
config.viewer,
Expand All @@ -108,7 +108,7 @@ def _start_viewer(config: TrainerConfig, pipeline: Pipeline, step: int):
pipeline=pipeline,
share=config.viewer.make_share_url,
)
banner_messages = [f"Viewer at: {viewer_state.viewer_url}"]
banner_messages = viewer_state.viewer_info

# We don't need logging, but writer.GLOBAL_BUFFER needs to be populated
config.logging.local_writer.enable = False
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def _print_stats(self, latest_map, padding=" "):

for i, mssg in enumerate(self.past_mssgs):
pad_len = len(max(self.past_mssgs, key=len))
style = "\x1b[6;30;42m" if self.banner_len and i >= len(self.past_mssgs) - self.banner_len + 1 else ""
style = "\x1b[30;42m" if self.banner_len and i >= len(self.past_mssgs) - self.banner_len + 1 else ""
print(f"{style}{mssg:{padding}<{pad_len}} \x1b[0m")
else:
print(curr_mssg)
32 changes: 20 additions & 12 deletions nerfstudio/viewer/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import numpy as np
import torch
import torchvision
import viser
import viser.theme
import viser.transforms as vtf
from nerfstudio.cameras.camera_optimizers import CameraOptimizer
Expand All @@ -34,15 +33,17 @@
from nerfstudio.pipelines.base_pipeline import Pipeline
from nerfstudio.utils.decorators import check_main_thread, decorate_all
from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName
from nerfstudio.viewer_legacy.server import viewer_utils
from nerfstudio.viewer.control_panel import ControlPanel
from nerfstudio.viewer.export_panel import populate_export_tab
from nerfstudio.viewer.render_panel import populate_render_tab
from nerfstudio.viewer.render_state_machine import RenderAction, RenderStateMachine
from nerfstudio.viewer.utils import CameraState, parse_object
from nerfstudio.viewer.viewer_elements import ViewerControl, ViewerElement
from nerfstudio.viewer_legacy.server import viewer_utils
from typing_extensions import assert_never

import viser

if TYPE_CHECKING:
from nerfstudio.engine.trainer import Trainer

Expand All @@ -63,13 +64,12 @@ class Viewer:
share: print a shareable URL
Attributes:
viewer_url: url to open viewer
viewer_info: information string for the viewer
viser_server: the viser server
"""

viewer_url: str
viewer_info: List[str]
viser_server: viser.ViserServer
camera_state: Optional[CameraState] = None

def __init__(
self,
Expand Down Expand Up @@ -107,15 +107,23 @@ def __init__(

self.viser_server = viser.ViserServer(host=config.websocket_host, port=websocket_port)
# Set the name of the URL either to the share link if available, or the localhost
share_url = None
if share:
url = self.viser_server.request_share_url()
if url is not None:
print("Couldn't make share URL")
self.viewer_url = url
else:
self.viewer_url = f"http://{config.websocket_host}:{websocket_port}"
share_url = self.viser_server.request_share_url()
if share_url is None:
print("Couldn't make share URL!")

if share_url is not None:
self.viewer_info = [f"Viewer at: http://localhost:{websocket_port} or {share_url}"]
elif config.websocket_host == "0.0.0.0":
# 0.0.0.0 is not a real IP address and was confusing people, so
# we'll just print localhost instead. There are some security
# (and IPv6 compatibility) implications here though, so we should
# note that the server is bound to 0.0.0.0!
self.viewer_info = [f"Viewer running locally at: http://localhost:{websocket_port} (listening on 0.0.0.0)"]
else:
self.viewer_url = f"http://{config.websocket_host}:{websocket_port}"
self.viewer_info = [f"Viewer running locally at: http://{config.websocket_host}:{websocket_port}"]

buttons = (
viser.theme.TitlebarButton(
text="Getting Started",
Expand Down

0 comments on commit 69ed6d4

Please sign in to comment.