pulsar interface unification.
Summary:
This diff builds on top of the `pulsar integration` diff to provide a unified interface for the existing PyTorch3D point renderer and Pulsar. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder docs/examples.

The unified interfaces are consistent: switching the rendering backend is as easy as using `renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)` instead of `renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)` and adding the `gamma` parameter to the forward call. All PyTorch3D camera types are supported as far as the backend allows, and keyword arguments are forwarded to the camera. The `PerspectiveCamera` and `OrthographicCamera` types additionally require `znear` and `zfar` parameters for the forward pass.
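As a sketch of the interchangeable call signatures described above — using minimal stand-in stubs, not the real PyTorch3D classes; `PointsRendererStub`, `PulsarPointsRendererStub`, and the toy rasterizer/compositor below are all hypothetical:

```python
# Hedged sketch: stub classes mirroring the two renderer call signatures.
# The classic renderer takes a rasterizer AND a compositor; the pulsar
# renderer takes only a rasterizer and a per-batch `gamma` at call time.

class PointsRendererStub:
    def __init__(self, rasterizer, compositor):
        self.rasterizer = rasterizer
        self.compositor = compositor

    def __call__(self, point_cloud):
        fragments = self.rasterizer(point_cloud)
        return self.compositor(fragments)

class PulsarPointsRendererStub:
    # Pulsar needs no separate compositor; blending happens in the backend.
    def __init__(self, rasterizer):
        self.rasterizer = rasterizer

    def __call__(self, point_cloud, gamma):
        # `gamma` stands in for pulsar's per-batch blending parameter.
        fragments = self.rasterizer(point_cloud)
        return [f * g for f, g in zip(fragments, gamma)]

# Toy stand-ins: collapse all points to a single "pixel" value.
rasterizer = lambda pc: [sum(pc)]
compositor = lambda fragments: fragments  # identity compositing

classic = PointsRendererStub(rasterizer=rasterizer, compositor=compositor)
pulsar = PulsarPointsRendererStub(rasterizer=rasterizer)

image_a = classic([1.0, 2.0, 3.0])
image_b = pulsar([1.0, 2.0, 3.0], gamma=(1.0,))
```

With the real classes the same pattern applies: only the constructor and the extra `gamma` argument differ between backends.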

Reviewed By: nikhilaravi

Differential Revision: D21421443

fbshipit-source-id: 4aa0a83a419592d9a0bb5d62486a1cdea9d73ce6
classner authored and facebook-github-bot committed Nov 3, 2020
1 parent b19fe1d commit 960fd6d
Showing 18 changed files with 695 additions and 313 deletions.
1 change: 1 addition & 0 deletions pytorch3d/renderer/points/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from .compositor import AlphaCompositor, NormWeightedCompositor
from .pulsar.unified import PulsarPointsRenderer
from .rasterize_points import rasterize_points
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
from .renderer import PointsRenderer
17 changes: 13 additions & 4 deletions pytorch3d/renderer/points/pulsar/renderer.py
@@ -369,6 +369,7 @@ def _transform_cam_params(
height: int,
orthogonal: bool,
right_handed: bool,
first_R_then_T: bool = False,
) -> Tuple[
torch.Tensor,
torch.Tensor,
@@ -401,6 +402,8 @@ def _transform_cam_params(
(does not use focal length).
* right_handed: bool, whether to use a right handed system
(negative z in camera direction).
* first_R_then_T: bool, whether to first rotate, then translate
the camera (PyTorch3D convention).
Returns:
* pos_vec: the position vector in 3D,
@@ -460,16 +463,18 @@
# Always get quadratic pixels.
pixel_size_x = sensor_size_x / float(width)
sensor_size_y = height * pixel_size_x
if continuous_rep:
rot_mat = rotation_6d_to_matrix(rot_vec)
else:
rot_mat = axis_angle_to_matrix(rot_vec)
if first_R_then_T:
pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
LOGGER.debug(
"Camera position: %s, rotation: %s. Focal length: %s.",
str(pos_vec),
str(rot_vec),
str(focal_length),
)
if continuous_rep:
rot_mat = rotation_6d_to_matrix(rot_vec)
else:
rot_mat = axis_angle_to_matrix(rot_vec)
sensor_dir_x = torch.matmul(
rot_mat,
torch.tensor(
@@ -576,6 +581,7 @@ def forward(
max_n_hits: int = _C.MAX_UINT,
mode: int = 0,
return_forward_info: bool = False,
first_R_then_T: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""
Rendering pass to create an image from the provided spheres and camera
@@ -616,6 +622,8 @@ def forward(
the float encoded integer index of a sphere and its weight. They are the
five spheres with the highest color contribution to this pixel color,
ordered descending. Default: False.
* first_R_then_T: bool, whether to first apply rotation to the camera,
then translation (PyTorch3D convention). Default: False.
Returns:
* image: [Bx]HxWx3 float tensor with the resulting image.
@@ -638,6 +646,7 @@
self._renderer.height,
self._renderer.orthogonal,
self._renderer.right_handed,
first_R_then_T=first_R_then_T,
)
if (
focal_lengths.min().item() > 0.0
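The `first_R_then_T` flag threaded through the diff above can be sketched in plain Python. This is a hedged illustration of the single diff line `pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]` for one unbatched camera; the `matvec` helper and the sample rotation are assumptions made for the example, not code from the repository:

```python
import math

def matvec(rot_mat, vec):
    """3x3 matrix times 3-vector, with plain Python lists."""
    return [sum(rot_mat[i][j] * vec[j] for j in range(3)) for i in range(3)]

def transform_position(rot_mat, pos_vec, first_R_then_T=False):
    # Pulsar's native convention uses pos_vec as-is; with the PyTorch3D
    # convention (first rotate, then translate) the position vector is
    # rotated by the camera rotation matrix before use.
    if first_R_then_T:
        pos_vec = matvec(rot_mat, pos_vec)
    return pos_vec

# Illustrative 90-degree rotation about the z axis:
theta = math.pi / 2.0
rot = [[math.cos(theta), -math.sin(theta), 0.0],
       [math.sin(theta),  math.cos(theta), 0.0],
       [0.0,              0.0,             1.0]]
pos = [1.0, 0.0, 0.0]

unchanged = transform_position(rot, pos)
rotated = transform_position(rot, pos, first_R_then_T=True)  # approx. [0, 1, 0]
```

The default `first_R_then_T=False` preserves pulsar's previous behaviour, so existing callers of the raw renderer are unaffected.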
