Skip to content

Commit

Permalink
Heterogeneous raysampling -> RayBundleHeterogeneous
Browse files Browse the repository at this point in the history
Summary:
Added heterogeneous raysampling to pytorch3d raysampler, different cameras are sampled different number of times.

 It now returns RayBundle if heterogeneous raysampling is off and new RayBundleHeterogeneous (with added fields `camera_ids` and `camera_counts`).  Heterogeneous raysampling is on if `n_rays_total` is not None.

Reviewed By: bottler

Differential Revision: D39542222

fbshipit-source-id: d3d88d822ec7696e856007c088dc36a1cfa8c625
  • Loading branch information
Darijan Gudelj authored and facebook-github-bot committed Sep 30, 2022
1 parent 9a0f9ae commit 6ae863f
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 48 deletions.
1 change: 1 addition & 0 deletions pytorch3d/renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
EmissionAbsorptionRaymarcher,
GridRaysampler,
HarmonicEmbedding,
HeterogeneousRayBundle,
ImplicitRenderer,
MonteCarloRaysampler,
MultinomialRaysampler,
Expand Down
1 change: 1 addition & 0 deletions pytorch3d/renderer/implicit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
from .utils import (
HeterogeneousRayBundle,
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
RayBundle,
Expand Down
208 changes: 189 additions & 19 deletions pytorch3d/renderer/implicit/raysampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Optional
from typing import Optional, Tuple, Union

import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.ops import padded_to_packed
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit.utils import RayBundle
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle, RayBundle
from torch.nn import functional as F


Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
min_depth: float,
max_depth: float,
n_rays_per_image: Optional[int] = None,
n_rays_total: Optional[int] = None,
unit_directions: bool = False,
stratified_sampling: bool = False,
) -> None:
Expand All @@ -88,6 +90,11 @@ def __init__(
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
n_rays_per_image: If given, this amount of rays are sampled from the grid.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables
`n_rays_per_image` and returns the HeterogeneousRayBundle with
batch_size=n_rays_total.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets.
Expand All @@ -97,6 +104,7 @@ def __init__(
self._min_depth = min_depth
self._max_depth = max_depth
self._n_rays_per_image = n_rays_per_image
self._n_rays_total = n_rays_total
self._unit_directions = unit_directions
self._stratified_sampling = stratified_sampling

Expand Down Expand Up @@ -125,8 +133,9 @@ def forward(
n_rays_per_image: Optional[int] = None,
n_pts_per_ray: Optional[int] = None,
stratified_sampling: Optional[bool] = None,
n_rays_total: Optional[int] = None,
**kwargs,
) -> RayBundle:
) -> Union[RayBundle, HeterogeneousRayBundle]:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
Expand All @@ -138,8 +147,15 @@ def forward(
n_pts_per_ray: The number of points sampled along each ray.
stratified_sampling: if set, overrides stratified_sampling provided
in __init__.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables
`n_rays_per_image` and returns the HeterogeneousRayBundle with
batch_size=n_rays_total.
Returns:
A named tuple RayBundle with the following fields:
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
following fields:
origins: A tensor of shape
`(batch_size, s1, s2, 3)`
denoting the locations of ray origins in the world coordinates.
Expand All @@ -153,23 +169,56 @@ def forward(
`(batch_size, s1, s2, 2)`
containing the 2D image coordinates of each ray or,
if mask is given, `(batch_size, n, 1, 2)`
Here `s1, s2` refer to spatial dimensions. Unless the mask is
given, they equal `(image_height, image_width)`, otherwise `(n, 1)`,
where `n` is `n_rays_per_image` if provided, otherwise the minimum
cardinality of the mask in the batch.
Here `s1, s2` refer to spatial dimensions.
`(s1, s2)` refer to (highest priority first):
- `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
- `(n_rays_per_image, 1) if `n_rays_per_image` if provided,
- `(n, 1)` where n is the minimum cardinality of the mask
in the batch if `mask` is provided
- `(image_height, image_width)` if nothing from above is satisfied
`HeterogeneousRayBundle` has additional members:
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
cameras. It represents unique ids of sampled cameras.
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
cameras. Represents how many times each camera from `camera_ids` was sampled
`HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle`
is returned.
"""
n_rays_total = n_rays_total or self._n_rays_total
n_rays_per_image = n_rays_per_image or self._n_rays_per_image
assert (n_rays_total is None) or (
n_rays_per_image is None
), "`n_rays_total` and `n_rays_per_image` cannot both be defined."
if n_rays_total:
(
cameras,
mask,
camera_ids, # unique ids of sampled cameras
camera_counts, # number of times unique camera id was sampled
# `n_rays_per_image` is equal to the max number of times a simgle camera
# was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times
# and then discard the unneeded rays.
# pyre-ignore[9]
n_rays_per_image,
) = _sample_cameras_and_masks(n_rays_total, cameras, mask)
else:
camera_ids = torch.range(0, len(cameras), dtype=torch.long)

batch_size = cameras.R.shape[0]
device = cameras.device

# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)

num_rays = n_rays_per_image or self._n_rays_per_image
if mask is not None and num_rays is None:
if mask is not None and n_rays_per_image is None:
# if num rays not given, sample according to the smallest mask
num_rays = num_rays or mask.sum(dim=(1, 2)).min().int().item()
n_rays_per_image = (
n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item()
)

if num_rays is not None:
if n_rays_per_image is not None:
if mask is not None:
assert mask.shape == xy_grid.shape[:3]
weights = mask.reshape(batch_size, -1)
Expand All @@ -181,7 +230,9 @@ def forward(
weights = xy_grid.new_ones(batch_size, width * height)
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
# float, int]`.
rays_idx = _safe_multinomial(weights, num_rays)[..., None].expand(-1, -1, 2)
rays_idx = _safe_multinomial(weights, n_rays_per_image)[..., None].expand(
-1, -1, 2
)

xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
:, :, None
Expand All @@ -198,7 +249,7 @@ def forward(
else self._stratified_sampling
)

return _xy_to_ray_bundle(
ray_bundle = _xy_to_ray_bundle(
cameras,
xy_grid,
min_depth,
Expand All @@ -208,6 +259,13 @@ def forward(
stratified_sampling,
)

return (
# pyre-ignore[61]
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
if n_rays_total
else ray_bundle
)


class NDCMultinomialRaysampler(MultinomialRaysampler):
"""
Expand All @@ -231,6 +289,7 @@ def __init__(
min_depth: float,
max_depth: float,
n_rays_per_image: Optional[int] = None,
n_rays_total: Optional[int] = None,
unit_directions: bool = False,
stratified_sampling: bool = False,
) -> None:
Expand All @@ -254,6 +313,7 @@ def __init__(
min_depth=min_depth,
max_depth=max_depth,
n_rays_per_image=n_rays_per_image,
n_rays_total=n_rays_total,
unit_directions=unit_directions,
stratified_sampling=stratified_sampling,
)
Expand Down Expand Up @@ -281,6 +341,7 @@ def __init__(
min_depth: float,
max_depth: float,
*,
n_rays_total: Optional[int] = None,
unit_directions: bool = False,
stratified_sampling: bool = False,
) -> None:
Expand All @@ -294,6 +355,11 @@ def __init__(
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of each ray-point.
max_depth: The maximum depth of each ray-point.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables
`n_rays_per_image` and returns the HeterogeneousRayBundleyBundle with
batch_size=n_rays_total.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
bins for each ray; otherwise takes n_pts_per_ray deterministic points
Expand All @@ -308,6 +374,7 @@ def __init__(
self._n_pts_per_ray = n_pts_per_ray
self._min_depth = min_depth
self._max_depth = max_depth
self._n_rays_total = n_rays_total
self._unit_directions = unit_directions
self._stratified_sampling = stratified_sampling

Expand All @@ -317,15 +384,16 @@ def forward(
*,
stratified_sampling: Optional[bool] = None,
**kwargs,
) -> RayBundle:
) -> Union[RayBundle, HeterogeneousRayBundle]:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
stratified_sampling: if set, overrides stratified_sampling provided
in __init__.
Returns:
A named tuple RayBundle with the following fields:
A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
following fields:
origins: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
denoting the locations of ray origins in the world coordinates.
Expand All @@ -338,7 +406,31 @@ def forward(
xys: A tensor of shape
`(batch_size, n_rays_per_image, 2)`
containing the 2D image coordinates of each ray.
If `n_rays_total` is provided `batch_size=n_rays_total`and
`n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle`
is returned.
`HeterogeneousRayBundle` has additional members:
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
cameras. It represents unique ids of sampled cameras.
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
cameras. Represents how many times each camera from `camera_ids` was sampled
"""
assert (self._n_rays_total is None) or (
self._n_rays_per_image is None
), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined."

if self._n_rays_total:
(
cameras,
_,
camera_ids,
camera_counts,
n_rays_per_image,
) = _sample_cameras_and_masks(self._n_rays_total, cameras, None)
else:
camera_ids = torch.range(0, len(cameras), dtype=torch.long)
n_rays_per_image = self._n_rays_per_image

batch_size = cameras.R.shape[0]

Expand All @@ -349,7 +441,7 @@ def forward(
rays_xy = torch.cat(
[
torch.rand(
size=(batch_size, self._n_rays_per_image, 1),
size=(batch_size, n_rays_per_image, 1),
dtype=torch.float32,
device=device,
)
Expand All @@ -369,7 +461,7 @@ def forward(
else self._stratified_sampling
)

return _xy_to_ray_bundle(
ray_bundle = _xy_to_ray_bundle(
cameras,
rays_xy,
self._min_depth,
Expand All @@ -379,6 +471,13 @@ def forward(
stratified_sampling,
)

return (
# pyre-ignore[61]
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
if self._n_rays_total
else ray_bundle
)


# Settings for backwards compatibility
def GridRaysampler(
Expand Down Expand Up @@ -602,3 +701,74 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
# Samples in those intervals.
jiggled = lower + (upper - lower) * torch.rand_like(lower)
return jiggled


def _sample_cameras_and_masks(
n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None
) -> Tuple[
CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
]:
"""
Samples n_rays_total cameras and masks and returns them in a form
(camera_idx, count), where count represents number of times the same camera
has been sampled.
Args:
n_samples: how many camera and mask pairs to sample
cameras: A batch of `batch_size` cameras from which the rays are emitted.
mask: Optional. Should be of size (batch_size, image_height, image_width).
Returns:
tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids,
number_of_times_each_sampled_camera_has_been_sampled,
max_number_of_times_camera_has_been_sampled,
)
"""
sampled_ids = torch.randint(
0,
len(cameras),
size=(n_samples,),
dtype=torch.long,
)
unique_ids, counts = torch.unique(sampled_ids, return_counts=True)
return (
cameras[unique_ids],
mask[unique_ids] if mask is not None else None,
unique_ids,
counts,
torch.max(counts),
)


def _pack_ray_bundle(
ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor
) -> HeterogeneousRayBundle:
"""
Pack the raybundle from [n_cameras, max(rays_per_camera), ...] to
[total_num_rays, 1, ...]
Args:
ray_bundle: A ray_bundle to pack
camera_ids: Unique ids of cameras that were sampled
camera_counts: how many of which camera to pack, each count coresponds to
one 'row' of the ray_bundle and says how many rays wll be taken
from it and packed.
Returns:
HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1
"""
camera_counts = camera_counts.to(ray_bundle.origins.device)
cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
first_idxs = torch.cat(
(camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
)
num_inputs = int(camera_counts.sum())

return HeterogeneousRayBundle(
origins=padded_to_packed(ray_bundle.origins, first_idxs, num_inputs)[:, None],
directions=padded_to_packed(ray_bundle.directions, first_idxs, num_inputs)[
:, None
],
lengths=padded_to_packed(ray_bundle.lengths, first_idxs, num_inputs)[:, None],
xys=padded_to_packed(ray_bundle.xys, first_idxs, num_inputs)[:, None],
camera_ids=camera_ids,
camera_counts=camera_counts,
)
Loading

0 comments on commit 6ae863f

Please sign in to comment.