Add utils to approximate the conical frustums as multivariate gaussians.
Summary:
Introduce methods to approximate the radii of conical frustums along rays as described in [MipNerf](https://arxiv.org/abs/2103.13415):

- Two new attributes are added to ImplicitronRayBundle: bins and radii (the latter is named `pixel_radii_2d` in the code). Bins has size n_pts_per_ray + 1, which makes it easy to manipulate the n_pts_per_ray intervals. For example, the radii computation needs the interval coordinates to obtain \(t_{\mu}, t_{\delta}\). Radii store the base radii of the conical frustums.

- Add 3 new methods to compute the radii:
   - approximate_conical_frustum_as_gaussians: Computes the mean along the ray direction, the variance of the
      conical frustum with respect to t, and the variance of the conical frustum with respect to its radius. This
      implementation follows the stable computation defined in the paper.
   - compute_3d_diagonal_covariance_gaussian: Uses the two previously computed variances to find the
     diagonal covariance of the Gaussian.
   - conical_frustum_to_gaussian: Combines the two methods above to compute the means and the diagonal
     covariances of the Gaussians along the ray.

- In AbstractMaskRaySampler, introduce the attribute `cast_ray_bundle_as_cone`. If False, the previous behaviour of the RaySampler is unchanged. If True, the samplers sample `n_pts_per_ray + 1` points instead of `n_pts_per_ray`. These points are then used to set the bins attribute of ImplicitronRayBundle. Support for HeterogeneousRayBundle has not been added since the current code does not allow it. A safeguard has been added to avoid a silent bug in the future.
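
As a rough illustration of the bins attribute described above, the plain-Python sketch below turns the n_pts_per_ray + 1 bin edges of a single ray into the n_pts_per_ray interval midpoints \(t_{\mu}\) and half-widths \(t_{\delta}\). It is not part of the commit: `bins_to_intervals` is a hypothetical helper name, and the commit itself works on batched tensors with `torch.lerp` and `torch.diff`.

```python
def bins_to_intervals(bins):
    """Hypothetical helper: split bin edges of one ray into (t_mu, t_delta).

    bins holds n_pts_per_ray + 1 edges; each consecutive pair is one interval.
    """
    if len(bins) < 2:
        raise ValueError("bins must contain at least 2 edges.")
    pairs = list(zip(bins[:-1], bins[1:]))
    t_mu = [0.5 * (left + right) for left, right in pairs]     # interval midpoints
    t_delta = [0.5 * (right - left) for left, right in pairs]  # interval half-widths
    return t_mu, t_delta


bins = [2.0, 3.0, 5.0, 8.0]  # 4 edges, so 3 intervals
t_mu, t_delta = bins_to_intervals(bins)
# t_mu    == [2.5, 4.0, 6.5]
# t_delta == [0.5, 1.0, 1.5]
```

The midpoints are also what `ImplicitronRayBundle.from_bins` stores in `lengths`.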

Reviewed By: shapovalov

Differential Revision: D45269190

fbshipit-source-id: bf22fad12d71d55392f054e3f680013aa0d59b78
EmGarr authored and facebook-github-bot committed Jul 6, 2023
1 parent 4e7715c commit 29b8ebd
Showing 10 changed files with 978 additions and 66 deletions.
4 changes: 4 additions & 0 deletions projects/implicitron_trainer/tests/experiment.yaml
@@ -216,6 +216,7 @@ model_factory_ImplicitronModelFactory_args:
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
cast_ray_bundle_as_cone: false
scene_extent: 8.0
scene_center:
- 0.0
@@ -228,6 +229,7 @@ model_factory_ImplicitronModelFactory_args:
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
cast_ray_bundle_as_cone: false
min_depth: 0.1
max_depth: 8.0
renderer_LSTMRenderer_args:
@@ -642,6 +644,7 @@ model_factory_ImplicitronModelFactory_args:
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
cast_ray_bundle_as_cone: false
scene_extent: 8.0
scene_center:
- 0.0
@@ -654,6 +657,7 @@ model_factory_ImplicitronModelFactory_args:
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
cast_ray_bundle_as_cone: false
min_depth: 0.1
max_depth: 8.0
renderer_LSTMRenderer_args:
212 changes: 212 additions & 0 deletions pytorch3d/implicitron/models/renderer/base.py
@@ -16,6 +16,7 @@
import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.ops import packed_to_padded
from pytorch3d.renderer.implicit.utils import ray_bundle_variables_to_ray_points


class EvaluationMode(Enum):
@@ -47,6 +48,27 @@ class ImplicitronRayBundle:
        camera_counts: A tensor of shape (N, ) which indicates how many times the
            corresponding camera in `camera_ids` was sampled.
            `sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`.
    Attributes:
        origins: A tensor of shape `(..., 3)` denoting the
            origins of the sampling rays in world coords.
        directions: A tensor of shape `(..., 3)` containing the direction
            vectors of sampling rays in world coords. They don't have to be normalized;
            they define unit vectors in the respective 1D coordinate systems; see
            documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
        lengths: A tensor of shape `(..., num_points_per_ray)`
            containing the lengths at which the rays are sampled.
        xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
        camera_ids: An optional tensor of shape (N, ) which indicates which camera
            was used to sample the rays. `N` is the number of unique sampled cameras.
        camera_counts: An optional tensor of shape (N, ) which indicates how many times the
            corresponding camera in `camera_ids` was sampled.
            `sum(camera_counts)==total_number_of_rays`.
        bins: An optional tensor of shape `(..., num_points_per_ray + 1)`
            containing the bins at which the rays are sampled. In this case
            lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
        pixel_radii_2d: An optional tensor of shape `(..., 1)`
            base radii of the conical frustums.
    """

    origins: torch.Tensor
@@ -55,6 +77,45 @@ class ImplicitronRayBundle:
    xys: torch.Tensor
    camera_ids: Optional[torch.LongTensor] = None
    camera_counts: Optional[torch.LongTensor] = None
    bins: Optional[torch.Tensor] = None
    pixel_radii_2d: Optional[torch.Tensor] = None

    @classmethod
    def from_bins(
        cls,
        origins: torch.Tensor,
        directions: torch.Tensor,
        bins: torch.Tensor,
        xys: torch.Tensor,
        **kwargs,
    ) -> "ImplicitronRayBundle":
        """
        Creates a new instance from bins instead of lengths.
        Args:
            origins: A tensor of shape `(..., 3)` denoting the
                origins of the sampling rays in world coords.
            directions: A tensor of shape `(..., 3)` containing the direction
                vectors of sampling rays in world coords. They don't have to be normalized;
                they define unit vectors in the respective 1D coordinate systems; see
                documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
            bins: A tensor of shape `(..., num_points_per_ray + 1)`
                containing the bins at which the rays are sampled. In this case
                lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
            xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
            kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
        Returns:
            An instance of ImplicitronRayBundle.
        """

        if bins.shape[-1] <= 1:
            raise ValueError(
                "The last dim of bins must be greater than or equal to 2."
            )
        # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
        lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)

        return cls(origins, directions, lengths, xys, bins=bins, **kwargs)

    def is_packed(self) -> bool:
        """
@@ -195,3 +256,154 @@ def forward(
            instance of RendererOutput
        """
        pass


def compute_3d_diagonal_covariance_gaussian(
    rays_directions: torch.Tensor,
    rays_dir_variance: torch.Tensor,
    radii_variance: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Transform the variances (rays_dir_variance, radii_variance) of the gaussians from
    the coordinate frame of the conical frustum to 3D world coordinates.
    It follows the equation 16 of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_
    Args:
        rays_directions: A tensor of shape `(..., 3)`
        rays_dir_variance: A tensor of shape `(..., num_intervals)` representing
            the variance of the conical frustum with respect to the rays direction.
        radii_variance: A tensor of shape `(..., num_intervals)` representing
            the variance of the conical frustum with respect to its radius.
        eps: a small number to prevent division by zero.
    Returns:
        A tensor of shape `(..., num_intervals, 3)` containing the diagonal
            of the covariance matrix.
    """
    d_outer_diag = torch.pow(rays_directions, 2)
    dir_mag_sq = torch.clamp(torch.sum(d_outer_diag, dim=-1, keepdim=True), min=eps)

    null_outer_diag = 1 - d_outer_diag / dir_mag_sq
    ray_dir_cov_diag = rays_dir_variance[..., None] * d_outer_diag[..., None, :]
    xy_cov_diag = radii_variance[..., None] * null_outer_diag[..., None, :]
    return ray_dir_cov_diag + xy_cov_diag
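
To make the diagonal computation above easier to follow, here is a minimal plain-Python sketch for a single ray and a single interval. It is an illustration rather than the committed code: `diagonal_covariance_single_ray` is a hypothetical name, and lists stand in for tensors. It evaluates the diagonal of `t_var * d dᵀ + r_var * (I - d dᵀ / |d|²)` without ever building the 3x3 matrix.

```python
def diagonal_covariance_single_ray(direction, t_var, r_var, eps=1e-6):
    """Hypothetical single-ray version of the diagonal covariance.

    direction: 3 floats, not necessarily normalized.
    t_var: variance along the ray, r_var: variance perpendicular to it.
    """
    d_sq = [c * c for c in direction]   # diagonal of d dᵀ
    mag_sq = max(sum(d_sq), eps)        # |d|², clamped to avoid division by zero
    return [t_var * dd + r_var * (1.0 - dd / mag_sq) for dd in d_sq]


# A ray pointing along z with length 2: the along-ray variance lands on z
# (scaled by |d|² = 4), the radial variance lands on x and y.
cov = diagonal_covariance_single_ray([0.0, 0.0, 2.0], t_var=0.5, r_var=0.1)
# cov is approximately [0.1, 0.1, 2.0]
```

Note that the along-ray term is not divided by `|d|²`, since `t` is measured in units of the (possibly unnormalized) direction vector.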


def approximate_conical_frustum_as_gaussians(
    bins: torch.Tensor, radii: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Approximates a conical frustum as two Gaussian distributions.
    The Gaussian distributions are characterized by
    three values:
    - rays_dir_mean: mean along the rays direction
        (defined as t in the parametric representation of a cone).
    - rays_dir_variance: the variance of the conical frustum along the rays direction.
    - radii_variance: variance of the conical frustum with respect to its radius.
    The computation is stable and follows equation 7
    of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
    For more information on how the mean and variances are computed
    refer to the appendix of the paper.
    Args:
        bins: A tensor of shape `(..., num_points_per_ray + 1)`
            containing the bins at which the rays are sampled.
            `bins[..., t]` and `bins[..., t+1]` represent respectively
            the left and right coordinates of the interval.
        radii: A tensor of shape `(..., 1)`
            base radii of the conical frustums.
    Returns:
        rays_dir_mean: A tensor of shape `(..., num_intervals)` representing
            the mean along the rays direction
            (t in the parametric representation of the cone)
        rays_dir_variance: A tensor of shape `(..., num_intervals)` representing
            the variance of the conical frustum along the rays
            (t in the parametric representation of the cone).
        radii_variance: A tensor of shape `(..., num_intervals)` representing
            the variance of the conical frustum with respect to its radius.
    """
    t_mu = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
    t_delta = torch.diff(bins, dim=-1) / 2

    t_mu_pow2 = torch.pow(t_mu, 2)
    t_delta_pow2 = torch.pow(t_delta, 2)
    t_delta_pow4 = torch.pow(t_delta, 4)

    den = 3 * t_mu_pow2 + t_delta_pow2

    # mean along the rays direction
    rays_dir_mean = t_mu + 2 * t_mu * t_delta_pow2 / den

    # Variance of the conical frustum along the rays direction
    rays_dir_variance = t_delta_pow2 / 3 - (4 / 15) * (
        t_delta_pow4 * (12 * t_mu_pow2 - t_delta_pow2) / torch.pow(den, 2)
    )

    # Variance of the conical frustum with respect to its radius
    radii_variance = torch.pow(radii, 2) * (
        t_mu_pow2 / 4 + (5 / 12) * t_delta_pow2 - 4 / 15 * (t_delta_pow4) / den
    )
    return rays_dir_mean, rays_dir_variance, radii_variance
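
As a sanity check on the stable formulas above, the following plain-Python sketch compares them with the direct moments of a frustum whose cross-section area grows as t². It is an illustration under stated assumptions, not part of the commit: `stable_moments` and `direct_moments` are hypothetical names, and the direct form is derived here by integrating t³ and t⁴ against the t² weight and using the variance (radius)²/4 of a uniform disc along one axis.

```python
import math


def stable_moments(t0, t1, base_radius):
    """Stable form in terms of the midpoint t_mu and half-width t_delta."""
    t_mu = 0.5 * (t0 + t1)
    t_delta = 0.5 * (t1 - t0)
    den = 3.0 * t_mu**2 + t_delta**2
    mean = t_mu + 2.0 * t_mu * t_delta**2 / den
    t_var = t_delta**2 / 3.0 - (4.0 / 15.0) * (
        t_delta**4 * (12.0 * t_mu**2 - t_delta**2) / den**2
    )
    r_var = base_radius**2 * (
        t_mu**2 / 4.0 + (5.0 / 12.0) * t_delta**2 - (4.0 / 15.0) * t_delta**4 / den
    )
    return mean, t_var, r_var


def direct_moments(t0, t1, base_radius):
    """Direct form in t0, t1; loses precision when t0 is close to t1."""
    vol = t1**3 - t0**3
    mean = 3.0 * (t1**4 - t0**4) / (4.0 * vol)
    second_moment = 3.0 * (t1**5 - t0**5) / (5.0 * vol)
    t_var = second_moment - mean**2
    r_var = base_radius**2 * 3.0 * (t1**5 - t0**5) / (20.0 * vol)
    return mean, t_var, r_var


stable = stable_moments(2.0, 4.0, 0.01)
direct = direct_moments(2.0, 4.0, 0.01)
# Both forms agree on a wide interval where the direct form is well conditioned.
agree = all(math.isclose(s, d, rel_tol=1e-9) for s, d in zip(stable, direct))
```

The direct form subtracts nearly equal quantities when the interval is narrow, which is the reason the midpoint and half-width parametrization is preferred.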


def conical_frustum_to_gaussian(
    ray_bundle: ImplicitronRayBundle,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Approximate a conical frustum following a ray bundle as a Gaussian.
    Args:
        ray_bundle: An `ImplicitronRayBundle` object with fields:
            origins: A tensor of shape `(..., 3)`
            directions: A tensor of shape `(..., 3)`
            lengths: A tensor of shape `(..., num_points_per_ray)`
            bins: A tensor of shape `(..., num_points_per_ray + 1)`
                containing the bins at which the rays are sampled.
            pixel_radii_2d: A tensor of shape `(..., 1)`
                base radii of the conical frustums.
    Returns:
        means: A tensor of shape `(..., num_points_per_ray, 3)`
            representing the means of the Gaussians
            approximating the conical frustums.
        diag_covariances: A tensor of shape `(..., num_points_per_ray, 3)`
            representing the diagonal covariance matrices of our Gaussians.
    """

    if ray_bundle.pixel_radii_2d is None or ray_bundle.bins is None:
        raise ValueError(
            "RayBundle pixel_radii_2d or bins have not been provided."
            " Look at pytorch3d.renderer.implicit.renderer.ray_sampler::"
            "AbstractMaskRaySampler to see how to compute them. Have you forgotten to set "
            "`cast_ray_bundle_as_cone` to True?"
        )

    (
        rays_dir_mean,
        rays_dir_variance,
        radii_variance,
    ) = approximate_conical_frustum_as_gaussians(
        ray_bundle.bins,
        ray_bundle.pixel_radii_2d,
    )
    means = ray_bundle_variables_to_ray_points(
        ray_bundle.origins, ray_bundle.directions, rays_dir_mean
    )
    diag_covariances = compute_3d_diagonal_covariance_gaussian(
        ray_bundle.directions, rays_dir_variance, radii_variance
    )
    return means, diag_covariances
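
The last step, turning the per-interval mean distance into a 3D point, can be sketched for one ray in plain Python. This is a hedged illustration: `frustum_means_single_ray` is a hypothetical name, and it assumes the ray points are `origin + t * direction`, with one Gaussian mean per interval between consecutive bin edges.

```python
def frustum_means_single_ray(origin, direction, bins):
    """Hypothetical single-ray sketch: one 3D mean per interval of `bins`."""
    means = []
    for t0, t1 in zip(bins[:-1], bins[1:]):
        t_mu = 0.5 * (t0 + t1)
        t_delta = 0.5 * (t1 - t0)
        # Mean distance along the ray (stable form used in the diff above).
        t_mean = t_mu + 2.0 * t_mu * t_delta**2 / (3.0 * t_mu**2 + t_delta**2)
        means.append([o + t_mean * d for o, d in zip(origin, direction)])
    return means


bins = [2.0, 4.0, 6.0]  # 3 edges, so 2 intervals and 2 Gaussians
means = frustum_means_single_ray([1.0, 0.0, 0.0], [0.0, 0.0, 1.0], bins)
```

Each mean sits slightly past the interval midpoint, because the frustum has more volume at its far end.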