Skip to content

Commit

Permalink
visualize_reconstruction fixes
Browse files Browse the repository at this point in the history
Summary: Various fixes to get visualize_reconstruction running, and an interactive test for it.

Reviewed By: kjchalup

Differential Revision: D39286691

fbshipit-source-id: 88735034cc01736b24735bcb024577e6ab7ed336
  • Loading branch information
bottler authored and facebook-github-bot committed Sep 8, 2022
1 parent 34ad77b commit 6e25fe8
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 58 deletions.
4 changes: 2 additions & 2 deletions projects/implicitron_trainer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
To run training, pass a yaml config file, followed by a list of overridden arguments.
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
```shell
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args=data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf \
$dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' \
$dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
Expand All @@ -87,7 +87,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad

E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
```shell
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args=data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf \
$dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' \
$dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
Expand Down
11 changes: 1 addition & 10 deletions projects/implicitron_trainer/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,7 @@
from omegaconf import OmegaConf

from .. import experiment
from .utils import intercept_logs


def interactive_testing_requested() -> bool:
    """
    Report whether the user has opted in to interactive-only tests.

    Some tests are only useful when a person runs them by hand, so they are
    left out of regular runs. They are switched on by setting the environment
    variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.

    Returns:
        True if the variable is set to exactly "1"; False if it is unset or
        holds any other value.
    """
    flag = os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "")
    return flag == "1"
from .utils import interactive_testing_requested, intercept_logs


internal = os.environ.get("FB_TEST", False)
Expand Down
27 changes: 27 additions & 0 deletions projects/implicitron_trainer/tests/test_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest

from .. import visualize_reconstruction
from .utils import interactive_testing_requested

internal = os.environ.get("FB_TEST", False)


class TestVisualize(unittest.TestCase):
    """Interactive-only check that the reconstruction visualizer runs."""

    def test_from_defaults(self):
        """
        Run `visualize_reconstruction.main` with a small set of overrides.

        Does nothing unless interactive testing has been requested. When it
        has, the checkpoint directory is read from the `exp_dir` environment
        variable, which must be set (a missing variable raises KeyError).
        """
        if interactive_testing_requested():
            exp_dir = os.environ["exp_dir"]
            overrides = [
                f"exp_dir={exp_dir}",
                "n_eval_cameras=40",
                "render_size=[64,64]",
                "video_size=[256,256]",
            ]
            visualize_reconstruction.main(overrides)
10 changes: 10 additions & 0 deletions projects/implicitron_trainer/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import contextlib
import logging
import os
import re


Expand All @@ -28,3 +29,12 @@ def filter(self, record):
yield intercepted_messages
finally:
logger.removeFilter(interceptor)


def interactive_testing_requested() -> bool:
    """
    Tell whether interactive-only tests should run.

    Certain tests are only worthwhile when run interactively and are therefore
    not part of regular runs. The user enables them by setting the environment
    variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.

    Returns:
        True only when the variable's value is exactly "1"; an unset variable
        or any other value gives False.
    """
    requested = os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "")
    return requested == "1"
58 changes: 36 additions & 22 deletions projects/implicitron_trainer/visualize_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Script to visualize a previously trained model. Example call:
"""
Script to visualize a previously trained model. Example call:
projects/implicitron_trainer/visualize_reconstruction.py
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
pytorch3d_implicitron_visualizer \
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097 \
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""

Expand All @@ -18,9 +19,9 @@

import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.visualization import render_flyaround
from pytorch3d.implicitron.tools.configurable import get_default_args
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
from pytorch3d.implicitron.tools.config import enable_get_default_args, get_default_args

from .experiment import Experiment

Expand All @@ -38,7 +39,7 @@ def visualize_reconstruction(
visdom_server: str = "http://127.0.0.1",
visdom_port: int = 8097,
visdom_env: Optional[str] = None,
):
) -> None:
"""
Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
of renders of sequences from the dataset used to train and evaluate the trained
Expand Down Expand Up @@ -76,22 +77,27 @@ def visualize_reconstruction(
config = _get_config_from_experiment_directory(exp_dir)
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
dataset_args = (
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataset_args.test_on_train = False
data_source_args = config.data_source_ImplicitronDataSource_args
if "dataset_map_provider_JsonIndexDatasetMapProvider_args" in data_source_args:
dataset_args = (
data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataset_args.test_on_train = False
if restrict_sequence_name is not None:
dataset_args.restrict_sequence_name = restrict_sequence_name

# Set the rendering image size
model_factory_args = config.model_factory_ImplicitronModelFactory_args
model_factory_args.force_resume = True
model_args = model_factory_args.model_GenericModel_args
model_args.render_image_width = render_size[0]
model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
dataset_args.restrict_sequence_name = restrict_sequence_name

# Load the previously trained model
experiment = Experiment(config)
model = experiment.model_factory(force_resume=True)
model.cuda()
experiment = Experiment(**config)
model = experiment.model_factory(exp_dir=exp_dir)
device = torch.device("cuda")
model.to(device)
model.eval()

# Setup the dataset
Expand All @@ -101,6 +107,11 @@ def visualize_reconstruction(
if dataset is None:
raise ValueError(f"{split} dataset not provided")

if visdom_env is None:
visdom_env = (
"visualizer_" + config.training_loop_ImplicitronTrainingLoop_args.visdom_env
)

# iterate over the sequences in the dataset
for sequence_name in dataset.sequence_names():
with torch.no_grad():
Expand All @@ -114,23 +125,26 @@ def visualize_reconstruction(
n_flyaround_poses=n_eval_cameras,
visdom_server=visdom_server,
visdom_port=visdom_port,
visdom_environment=f"visualizer_{config.visdom_env}"
if visdom_env is None
else visdom_env,
visdom_environment=visdom_env,
video_resize=video_size,
device=device,
)


def _get_config_from_experiment_directory(experiment_directory):
enable_get_default_args(visualize_reconstruction)


def _get_config_from_experiment_directory(experiment_directory) -> DictConfig:
cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
config = OmegaConf.load(cfg_file)
# pyre-ignore[7]
return config


def main(argv):
def main(argv) -> None:
# automatically parses arguments of visualize_reconstruction
cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
cfg.update(OmegaConf.from_cli())
cfg.update(OmegaConf.from_cli(argv))
with torch.no_grad():
visualize_reconstruction(**cfg)

Expand Down
8 changes: 7 additions & 1 deletion pytorch3d/implicitron/dataset/single_sequence_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# provide data for a single scene.

from dataclasses import field
from typing import Iterable, List, Optional
from typing import Iterable, Iterator, List, Optional, Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -46,6 +46,12 @@ def sequence_names(self) -> Iterable[str]:
def __len__(self) -> int:
return len(self.poses)

def sequence_frames_in_order(
    self, seq_name: str
) -> Iterator[Tuple[float, int, int]]:
    """
    Iterate over every frame of the dataset in index order.

    `seq_name` is not consulted: one tuple is produced for each index from
    0 to `len(self) - 1`, whatever name is passed.

    Yields:
        `(0.0, i, i)` for each index `i`.
        NOTE(review): presumably (timestamp, frame number, dataset index),
        mirroring the multi-sequence dataset API -- confirm against the
        base class.
    """
    yield from ((0.0, frame_idx, frame_idx) for frame_idx in range(len(self)))

def __getitem__(self, index) -> FrameData:
if index >= len(self):
raise IndexError(f"index {index} out of range {len(self)}")
Expand Down
15 changes: 8 additions & 7 deletions pytorch3d/implicitron/models/visualization/render_flyaround.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def render_flyaround(
"depths_render",
"_all_source_images",
),
):
) -> None:
"""
Uses `model` to generate a video consisting of renders of a scene imaged from
a camera flying around the scene. The scene is specified with the `dataset` object and
Expand Down Expand Up @@ -133,6 +133,7 @@ def render_flyaround(
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
# pyre-ignore[6]
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
logger.info(f"Sequence set = {sequence_set_name}.")
train_cameras = train_data.camera
Expand Down Expand Up @@ -209,7 +210,7 @@ def render_flyaround(

def _load_whole_dataset(
dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
):
) -> FrameData:
load_all_dataloader = torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, idx),
batch_size=len(idx),
Expand All @@ -220,7 +221,7 @@ def _load_whole_dataset(
return next(iter(load_all_dataloader))


def _images_from_preds(preds: Dict[str, Any]):
def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
imout = {}
for k in (
"image_rgb",
Expand Down Expand Up @@ -253,7 +254,7 @@ def _images_from_preds(preds: Dict[str, Any]):
return imout


def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.Tensor:
ba = ims.shape[0]
H = int(np.ceil(np.sqrt(ba)))
W = H
Expand Down Expand Up @@ -281,7 +282,7 @@ def _show_predictions(
),
n_samples=10,
one_image_width=200,
):
) -> None:
"""Given a list of predictions visualize them into a single image using visdom."""
assert isinstance(preds, list)

Expand Down Expand Up @@ -329,7 +330,7 @@ def _generate_prediction_videos(
video_path: str = "/tmp/video",
video_frames_dir: Optional[str] = None,
resize: Optional[Tuple[int, int]] = None,
):
) -> None:
"""Given a list of predictions create and visualize rotating videos of the
objects using visdom.
"""
Expand Down Expand Up @@ -359,7 +360,7 @@ def _generate_prediction_videos(
)

for k in predicted_keys:
vws[k].get_video(quiet=True)
vws[k].get_video()
logger.info(f"Generated {vws[k].out_path}.")
if viz is not None:
viz.video(
Expand Down
50 changes: 34 additions & 16 deletions pytorch3d/implicitron/tools/video_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import os
import shutil
import subprocess
import tempfile
import warnings
from typing import Optional, Tuple, Union
Expand All @@ -15,6 +16,7 @@
import numpy as np
from PIL import Image

_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")

matplotlib.use("Agg")

Expand All @@ -27,13 +29,13 @@ class VideoWriter:
def __init__(
self,
cache_dir: Optional[str] = None,
ffmpeg_bin: str = "ffmpeg",
ffmpeg_bin: str = _DEFAULT_FFMPEG,
out_path: str = "/tmp/video.mp4",
fps: int = 20,
output_format: str = "visdom",
rmdir_allowed: bool = False,
**kwargs,
):
) -> None:
"""
Args:
cache_dir: A directory for storing the video frames. If `None`,
Expand Down Expand Up @@ -74,7 +76,7 @@ def write_frame(
self,
frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
resize: Optional[Union[float, Tuple[int, int]]] = None,
):
) -> None:
"""
Write a frame to the video.
Expand Down Expand Up @@ -114,7 +116,7 @@ def write_frame(
self.frames.append(outfile)
self.frame_num += 1

def get_video(self, quiet: bool = True):
def get_video(self) -> str:
"""
Generate the video from the written frames.
Expand All @@ -127,23 +129,39 @@ def get_video(self, quiet: bool = True):

regexp = os.path.join(self.cache_dir, self.regexp)

if self.output_format == "visdom": # works for ppt too
ffmcmd_ = (
"%s -r %d -i %s -vcodec h264 -f mp4 \
-y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
% (self.ffmpeg_bin, self.fps, regexp, self.out_path)
if shutil.which(self.ffmpeg_bin) is None:
raise ValueError(
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
)
else:
raise ValueError("no such output type %s" % str(self.output_format))

if quiet:
ffmcmd_ += " > /dev/null 2>&1"
if self.output_format == "visdom": # works for ppt too
args = [
self.ffmpeg_bin,
"-r",
str(self.fps),
"-i",
regexp,
"-vcodec",
"h264",
"-f",
"mp4",
"-y",
"-crf",
"18",
"-b",
"2000k",
"-pix_fmt",
"yuv420p",
self.out_path,
]

subprocess.check_call(args)
else:
print(ffmcmd_)
os.system(ffmcmd_)
raise ValueError("no such output type %s" % str(self.output_format))

return self.out_path

def __del__(self):
def __del__(self) -> None:
if self.tmp_dir is not None:
self.tmp_dir.cleanup()

0 comments on commit 6e25fe8

Please sign in to comment.