From ad8907d3738fbf4c80aa269954d1d8ba4f307530 Mon Sep 17 00:00:00 2001 From: Darijan Gudelj Date: Mon, 3 Oct 2022 08:36:47 -0700 Subject: [PATCH] ImplicitronRayBundle Summary: New ImplicitronRayBundle with added `camera_ids` and `camera_counts`. Added to enable a single ray bundle type inside Implicitron and easier extension in the future. RayBundle is a named tuple and HeterogeneousRayBundle is a dataclass, so HeterogeneousRayBundle cannot inherit from RayBundle. Without ImplicitronRayBundle, every function that currently uses RayBundle would have to use Union[RayBundle, HeterogeneousRayBundle], which is confusing and unnecessarily complicated. Reviewed By: bottler, kjchalup Differential Revision: D39262999 fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59 --- docs/tutorials/implicitron_volumes.ipynb | 5 +- pytorch3d/implicitron/models/generic_model.py | 14 +- .../models/implicit_function/base.py | 5 +- .../implicit_function/idr_feature_field.py | 6 +- .../neural_radiance_field.py | 10 +- .../scene_representation_networks.py | 20 +-- .../models/implicit_function/utils.py | 7 +- pytorch3d/implicitron/models/renderer/base.py | 44 +++++- .../models/renderer/lstm_renderer.py | 20 +-- .../models/renderer/multipass_ea.py | 6 +- .../models/renderer/ray_sampler.py | 26 +++- .../implicitron/models/renderer/rgb_net.py | 6 +- .../models/renderer/sdf_renderer.py | 6 +- pytorch3d/renderer/implicit/raysampling.py | 28 ++-- pytorch3d/renderer/implicit/renderer.py | 1 + pytorch3d/vis/plotly_vis.py | 129 ++++++++++++++---- tests/implicitron/test_ray_point_refiner.py | 11 +- tests/implicitron/test_srn.py | 15 +- 18 files changed, 259 insertions(+), 100 deletions(-) diff --git a/docs/tutorials/implicitron_volumes.ipynb b/docs/tutorials/implicitron_volumes.ipynb index 605edae61..1af8af1a4 100644 --- a/docs/tutorials/implicitron_volumes.ipynb +++ b/docs/tutorials/implicitron_volumes.ipynb @@ -145,10 +145,9 @@ "from pytorch3d.implicitron.dataset.dataset_base import FrameData\n", "from 
pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n", "from pytorch3d.implicitron.models.generic_model import GenericModel\n", - "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n", + "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle\n", "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n", "from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n", - "from pytorch3d.renderer import RayBundle\n", "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n", "from pytorch3d.structures import Volumes\n", "from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene" @@ -393,7 +392,7 @@ "\n", " def forward(\n", " self,\n", - " ray_bundle: RayBundle,\n", + " ray_bundle: ImplicitronRayBundle,\n", " fun_viewpool=None,\n", " global_code=None,\n", " ):\n", diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index b3dabee23..853e84ef8 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -22,6 +22,7 @@ RegularizationMetricsBase, ViewMetricsBase, ) +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools import image_utils, vis_utils from pytorch3d.implicitron.tools.config import ( expand_args_fields, @@ -30,7 +31,8 @@ ) from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples from pytorch3d.implicitron.tools.utils import cat_dataclass -from pytorch3d.renderer import RayBundle, utils as rend_utils +from pytorch3d.renderer import utils as rend_utils + from pytorch3d.renderer.cameras import CamerasBase from visdom import Visdom @@ -387,7 +389,7 @@ def safe_slice_targets( ) # (1) Sample rendering rays with the ray sampler. 
- ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29] + ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29] target_cameras, evaluation_mode, mask=mask_crop[:n_targets] @@ -568,14 +570,14 @@ def visualize( def _render( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, chunked_inputs: Dict[str, torch.Tensor], sampling_mode: RenderSamplingMode, **kwargs, ) -> RendererOutput: """ Args: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g. SignedDistanceFunctionRenderer requires "object_mask", shape @@ -899,7 +901,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor: def _chunk_generator( chunk_size: int, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, chunked_inputs: Dict[str, torch.Tensor], tqdm_trigger_threshold: int, *args, @@ -932,7 +934,7 @@ def _chunk_generator( for start_idx in iter: end_idx = min(start_idx + chunk_size_in_rays, n_rays) - ray_bundle_chunk = RayBundle( + ray_bundle_chunk = ImplicitronRayBundle( origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx], directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ :, start_idx:end_idx diff --git a/pytorch3d/implicitron/models/implicit_function/base.py b/pytorch3d/implicitron/models/implicit_function/base.py index 2e0c77984..75bd36538 100644 --- a/pytorch3d/implicitron/models/implicit_function/base.py +++ b/pytorch3d/implicitron/models/implicit_function/base.py @@ -7,9 +7,10 @@ from abc import ABC, abstractmethod from typing import Optional +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle + from pytorch3d.implicitron.tools.config import ReplaceableBase from pytorch3d.renderer.cameras import CamerasBase -from pytorch3d.renderer.implicit import RayBundle class ImplicitFunctionBase(ABC, 
ReplaceableBase): @@ -20,7 +21,7 @@ def __init__(self): def forward( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, diff --git a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py index 557ba1387..f43a2932e 100644 --- a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py +++ b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py @@ -6,8 +6,10 @@ from typing import Optional, Tuple import torch +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import registry -from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle +from pytorch3d.renderer.implicit import HarmonicEmbedding + from torch import nn from .base import ImplicitFunctionBase @@ -127,7 +129,7 @@ def __post_init__(self): def forward( self, *, - ray_bundle: Optional[RayBundle] = None, + ray_bundle: Optional[ImplicitronRayBundle] = None, rays_points_world: Optional[torch.Tensor] = None, fun_viewpool=None, global_code=None, diff --git a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py index d325c798c..aecd91051 100644 --- a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py +++ b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py @@ -9,8 +9,9 @@ import torch from pytorch3d.common.linear_with_repeat import LinearWithRepeat +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import expand_args_fields, registry -from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle +from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.implicit import 
HarmonicEmbedding @@ -130,7 +131,7 @@ def allows_multiple_passes() -> bool: def forward( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, @@ -144,7 +145,7 @@ def forward( RGB color and opacity respectively. Args: - ray_bundle: A RayBundle object containing the following variables: + ray_bundle: An ImplicitronRayBundle object containing the following variables: origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins of the sampling rays in world coords. directions: A tensor of shape `(minibatch, ..., 3)` @@ -165,11 +166,12 @@ def forward( """ # We first convert the ray parametrizations to world # coordinates with `ray_bundle_to_ray_points`. + # pyre-ignore[6] rays_points_world = ray_bundle_to_ray_points(ray_bundle) # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3] embeds = create_embeddings_for_implicit_function( - xyz_world=ray_bundle_to_ray_points(ray_bundle), + xyz_world=rays_points_world, # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]` # for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`. 
xyz_embedding_function=self.harmonic_embedding_xyz diff --git a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py index c701c54c0..b9e3cc1e5 100644 --- a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py +++ b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py @@ -6,9 +6,10 @@ import torch from omegaconf import DictConfig from pytorch3d.common.linear_with_repeat import LinearWithRepeat +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation -from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle +from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.implicit import HarmonicEmbedding @@ -68,7 +69,7 @@ def __post_init__(self): def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, @@ -76,7 +77,7 @@ def forward( ): """ Args: - ray_bundle: A RayBundle object containing the following variables: + ray_bundle: An ImplicitronRayBundle object containing the following variables: origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins of the sampling rays in world coords. directions: A tensor of shape `(minibatch, ..., 3)` @@ -96,10 +97,11 @@ def forward( """ # We first convert the ray parametrizations to world # coordinates with `ray_bundle_to_ray_points`. 
+ # pyre-ignore[6] rays_points_world = ray_bundle_to_ray_points(ray_bundle) embeds = create_embeddings_for_implicit_function( - xyz_world=ray_bundle_to_ray_points(ray_bundle), + xyz_world=rays_points_world, # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]` # for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`. xyz_embedding_function=self._harmonic_embedding, @@ -175,7 +177,7 @@ def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor): def forward( self, raymarch_features: torch.Tensor, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, camera: Optional[CamerasBase] = None, **kwargs, ): @@ -183,7 +185,7 @@ def forward( Args: raymarch_features: Features from the raymarching network of shape `(minibatch, ..., self.in_features)` - ray_bundle: A RayBundle object containing the following variables: + ray_bundle: An ImplicitronRayBundle object containing the following variables: origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins of the sampling rays in world coords. 
directions: A tensor of shape `(minibatch, ..., 3)` @@ -297,7 +299,7 @@ def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction] def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, @@ -350,7 +352,7 @@ def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None: def forward( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, @@ -410,7 +412,7 @@ def hypernet_tweak_args(cls, type, args: DictConfig) -> None: def forward( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, diff --git a/pytorch3d/implicitron/models/implicit_function/utils.py b/pytorch3d/implicitron/models/implicit_function/utils.py index 9b401c489..e9b688efc 100644 --- a/pytorch3d/implicitron/models/implicit_function/utils.py +++ b/pytorch3d/implicitron/models/implicit_function/utils.py @@ -10,9 +10,9 @@ import torch.nn.functional as F from pytorch3d.common.compat import prod +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer.cameras import CamerasBase -from pytorch3d.renderer.implicit import RayBundle def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor): @@ -190,7 +190,7 @@ def interpolate_volume( def get_rays_points_world( - ray_bundle: Optional[RayBundle] = None, + ray_bundle: Optional[ImplicitronRayBundle] = None, rays_points_world: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ @@ -198,7 +198,7 @@ def get_rays_points_world( and raises error if both are defined. 
Args: - ray_bundle: A RayBundle object or None + ray_bundle: An ImplicitronRayBundle object or None rays_points_world: A torch.Tensor representing ray points converted to world coordinates Returns: @@ -213,5 +213,6 @@ def get_rays_points_world( if rays_points_world is not None: return rays_points_world if ray_bundle is not None: + # pyre-ignore[6] return ray_bundle_to_ray_points(ray_bundle) raise ValueError("ray_bundle and rays_points_world cannot both be None") diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py index f55059aac..27ee1787a 100644 --- a/pytorch3d/implicitron/models/renderer/base.py +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -6,6 +6,8 @@ from __future__ import annotations +import dataclasses + from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum @@ -25,6 +27,38 @@ class RenderSamplingMode(Enum): FULL_GRID = "full_grid" +@dataclasses.dataclass +class ImplicitronRayBundle: + """ + Parametrizes points along projection rays by storing ray `origins`, + `directions` vectors and `lengths` at which the ray-points are sampled. + Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well. + Note that `directions` don't have to be normalized; they define unit vectors + in the respective 1D coordinate systems; see documentation for + :func:`ray_bundle_to_ray_points` for the conversion formula. + + camera_ids: A tensor of shape (N, ) which indicates which camera + was used to sample the rays. `N` is the number of different + sampled cameras. + camera_counts: A tensor of shape (N, ) which how many times the + coresponding camera in `camera_ids` was sampled. 
+ `sum(camera_counts)==minibatch` + """ + + origins: torch.Tensor + directions: torch.Tensor + lengths: torch.Tensor + xys: torch.Tensor + camera_ids: Optional[torch.Tensor] = None + camera_counts: Optional[torch.Tensor] = None + + def is_packed(self) -> bool: + """ + Returns whether the ImplicitronRayBundle carries data in packed state + """ + return self.camera_ids is not None and self.camera_counts is not None + + @dataclass class RendererOutput: """ @@ -85,7 +119,7 @@ def requires_object_mask(self) -> bool: @abstractmethod def forward( self, - ray_bundle, + ray_bundle: ImplicitronRayBundle, implicit_functions: List[ImplicitFunctionWrapper], evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, **kwargs, @@ -95,7 +129,7 @@ def forward( that returns an instance of RendererOutput. Args: - ray_bundle: A RayBundle object containing the following variables: + ray_bundle: An ImplicitronRayBundle object containing the following variables: origins: A tensor of shape (minibatch, ..., 3) denoting the origins of the rendering rays. directions: A tensor of shape (minibatch, ..., 3) @@ -108,6 +142,12 @@ def forward( xys: A tensor of shape (minibatch, ..., 2) containing the xy locations of each ray's pixel in the NDC screen space. + camera_ids: A tensor of shape (N, ) which indicates which camera + was used to sample the rays. `N` is the number of different + sampled cameras. + camera_counts: A tensor of shape (N, ) which how many times the + coresponding camera in `camera_ids` was sampled. + `sum(camera_counts)==minibatch` implicit_functions: List of ImplicitFunctionWrappers which define the implicit function methods to be used. Most Renderers only allow a single implicit function. 
Currently, only the diff --git a/pytorch3d/implicitron/models/renderer/lstm_renderer.py b/pytorch3d/implicitron/models/renderer/lstm_renderer.py index c5ce094f5..b24c253f5 100644 --- a/pytorch3d/implicitron/models/renderer/lstm_renderer.py +++ b/pytorch3d/implicitron/models/renderer/lstm_renderer.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import dataclasses import logging from typing import List, Optional, Tuple import torch +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import registry -from pytorch3d.renderer import RayBundle from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput @@ -71,7 +72,7 @@ def __post_init__(self): def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, implicit_functions: List[ImplicitFunctionWrapper], evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, **kwargs, @@ -79,7 +80,7 @@ def forward( """ Args: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. implicit_functions: A single-element list of ImplicitFunctionWrappers which defines the implicit function to be used. 
@@ -102,9 +103,12 @@ def forward( ) # jitter the initial depths - ray_bundle_t = ray_bundle._replace( - lengths=ray_bundle.lengths - + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std + ray_bundle_t = dataclasses.replace( + ray_bundle, + lengths=( + ray_bundle.lengths + + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std + ), ) states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None] @@ -112,9 +116,7 @@ def forward( raymarch_features = None for t in range(self.num_raymarch_steps + 1): # move signed_distance along each ray - ray_bundle_t = ray_bundle_t._replace( - lengths=ray_bundle_t.lengths + signed_distance - ) + ray_bundle_t.lengths += signed_distance # eval the raymarching function raymarch_features, _ = implicit_function( diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py index 61cf0d4c3..648e7f37b 100644 --- a/pytorch3d/implicitron/models/renderer/multipass_ea.py +++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py @@ -7,8 +7,8 @@ from typing import List import torch +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import registry, run_auto_creation -from pytorch3d.renderer import RayBundle from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput from .ray_point_refiner import RayPointRefiner @@ -107,14 +107,14 @@ def __post_init__(self): def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, implicit_functions: List[ImplicitFunctionWrapper], evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, **kwargs, ) -> RendererOutput: """ Args: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. 
implicit_functions: List of ImplicitFunctionWrappers which define the implicit functions to be used sequentially in diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index 76f9f5bcb..6d3723ade 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -9,10 +9,10 @@ import torch from pytorch3d.implicitron.tools import camera_utils from pytorch3d.implicitron.tools.config import registry, ReplaceableBase -from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle +from pytorch3d.renderer import NDCMultinomialRaysampler from pytorch3d.renderer.cameras import CamerasBase -from .base import EvaluationMode, RenderSamplingMode +from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode class RaySamplerBase(ReplaceableBase): @@ -28,7 +28,7 @@ def forward( cameras: CamerasBase, evaluation_mode: EvaluationMode, mask: Optional[torch.Tensor] = None, - ) -> RayBundle: + ) -> ImplicitronRayBundle: """ Args: cameras: A batch of `batch_size` cameras from which the rays are emitted. @@ -42,7 +42,7 @@ def forward( corresponding pixel's ray. Returns: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. """ raise NotImplementedError() @@ -135,7 +135,7 @@ def forward( cameras: CamerasBase, evaluation_mode: EvaluationMode, mask: Optional[torch.Tensor] = None, - ) -> RayBundle: + ) -> ImplicitronRayBundle: """ Args: @@ -150,7 +150,7 @@ def forward( corresponding pixel's ray. Returns: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. 
""" sample_mask = None @@ -180,7 +180,19 @@ def forward( max_depth=max_depth, ) - return ray_bundle + if isinstance(ray_bundle, tuple): + return ImplicitronRayBundle( + # pyre-ignore[16] + **{k: v for k, v in ray_bundle._asdict().items()} + ) + return ImplicitronRayBundle( + directions=ray_bundle.directions, + origins=ray_bundle.origins, + lengths=ray_bundle.lengths, + xys=ray_bundle.xys, + camera_ids=ray_bundle.camera_ids, + camera_counts=ray_bundle.camera_counts, + ) @registry.register diff --git a/pytorch3d/implicitron/models/renderer/rgb_net.py b/pytorch3d/implicitron/models/renderer/rgb_net.py index 47609e83f..6d41d2165 100644 --- a/pytorch3d/implicitron/models/renderer/rgb_net.py +++ b/pytorch3d/implicitron/models/renderer/rgb_net.py @@ -7,8 +7,10 @@ from typing import List, Tuple import torch +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import enable_get_default_args -from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle +from pytorch3d.renderer.implicit import HarmonicEmbedding + from torch import nn @@ -89,7 +91,7 @@ def forward( feature_vectors: torch.Tensor, points, normals, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, masks=None, pooling_fn=None, ): diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py index 2f0e626c9..d8782911e 100644 --- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -8,13 +8,13 @@ import torch from omegaconf import DictConfig from pytorch3d.common.compat import prod +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import ( get_default_args_field, registry, run_auto_creation, ) from pytorch3d.implicitron.tools.utils import evaluating -from pytorch3d.renderer import RayBundle from .base import BaseRenderer, EvaluationMode, 
ImplicitFunctionWrapper, RendererOutput from .ray_tracing import RayTracing @@ -69,7 +69,7 @@ def requires_object_mask(self) -> bool: def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, implicit_functions: List[ImplicitFunctionWrapper], evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, object_mask: Optional[torch.Tensor] = None, @@ -77,7 +77,7 @@ def forward( ) -> RendererOutput: """ Args: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. implicit_functions: single element list of ImplicitFunctionWrappers which defines the implicit function to be used. diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index c53754e8f..033f783af 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -149,9 +149,8 @@ def forward( in __init__. n_rays_total: How many rays in total to sample from the cameras provided. The result is as if `n_rays_total_training` cameras were sampled with replacement from the - cameras provided and for every camera one ray was sampled. If set, this disables - `n_rays_per_image` and returns the HeterogeneousRayBundle with - batch_size=n_rays_total. + cameras provided and for every camera one ray was sampled. If set, returns the + HeterogeneousRayBundle with batch_size=n_rays_total. Returns: A named tuple RayBundle or dataclass HeterogeneousRayBundle with the following fields: @@ -188,9 +187,10 @@ def forward( """ n_rays_total = n_rays_total or self._n_rays_total n_rays_per_image = n_rays_per_image or self._n_rays_per_image - assert (n_rays_total is None) or ( - n_rays_per_image is None - ), "`n_rays_total` and `n_rays_per_image` cannot both be defined." 
+ if (n_rays_total is not None) and (n_rays_per_image is not None): + raise ValueError( + "`n_rays_total` and `n_rays_per_image` cannot both be defined." + ) if n_rays_total: ( cameras, @@ -357,9 +357,8 @@ def __init__( max_depth: The maximum depth of each ray-point. n_rays_total: How many rays in total to sample from the cameras provided. The result is as if `n_rays_total_training` cameras were sampled with replacement from the - cameras provided and for every camera one ray was sampled. If set, this disables - `n_rays_per_image` and returns the HeterogeneousRayBundleyBundle with - batch_size=n_rays_total. + cameras provided and for every camera one ray was sampled. If set, this returns + the HeterogeneousRayBundleyBundle with batch_size=n_rays_total. unit_directions: whether to normalize direction vectors in ray bundle. stratified_sampling: if True, performs stratified sampling in n_pts_per_ray bins for each ray; otherwise takes n_pts_per_ray deterministic points @@ -416,9 +415,14 @@ def forward( - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled cameras. Represents how many times each camera from `camera_ids` was sampled """ - assert (self._n_rays_total is None) or ( - self._n_rays_per_image is None - ), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined." + if ( + sum(x is not None for x in [self._n_rays_total, self._n_rays_per_image]) + != 1 + ): + raise ValueError( + "Exactly one of `self.n_rays_total` and `self.n_rays_per_image` " + "must be given." + ) if self._n_rays_total: ( diff --git a/pytorch3d/renderer/implicit/renderer.py b/pytorch3d/renderer/implicit/renderer.py index c2be5adcb..56583cdbe 100644 --- a/pytorch3d/renderer/implicit/renderer.py +++ b/pytorch3d/renderer/implicit/renderer.py @@ -297,6 +297,7 @@ def forward( """ Given an input ray parametrization, the forward function samples `self._volumes` at the respective 3D ray-points. + Can also accept ImplicitronRayBundle as argument for ray_bundle. 
Args: ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields: diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index 776f47688..1cb4985d1 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -11,6 +11,7 @@ import torch from plotly.subplots import make_subplots from pytorch3d.renderer import ( + HeterogeneousRayBundle, ray_bundle_to_ray_points, RayBundle, TexturesAtlas, @@ -21,14 +22,45 @@ from pytorch3d.structures import join_meshes_as_scene, Meshes, Pointclouds -Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle] +Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle, HeterogeneousRayBundle] -def _get_struct_len(struct: Struct) -> int: # pragma: no cover +def _get_len(struct: Union[Struct, List[Struct]]) -> int: # pragma: no cover """ Returns the length (usually corresponds to the batch size) of the input structure. """ - return len(struct.directions) if isinstance(struct, RayBundle) else len(struct) + # pyre-ignore[6] + if not _is_ray_bundle(struct): + # pyre-ignore[6] + return len(struct) + if _is_heterogeneous_ray_bundle(struct): + # pyre-ignore[16] + return len(struct.camera_counts) + # pyre-ignore[16] + return len(struct.directions) + + +def _is_ray_bundle(struct: Struct) -> bool: + """ + Args: + struct: Struct object to test + Returns: + True if something is a RayBundle, HeterogeneousRayBundle or + ImplicitronRayBundle, else False + """ + return hasattr(struct, "directions") + + +def _is_heterogeneous_ray_bundle(struct: Union[List[Struct], Struct]) -> bool: + """ + Args: + struct :object to test + Returns: + True if something is a HeterogeneousRayBundle or ImplicitronRayBundle + and cant be reduced to RayBundle else False + """ + # pyre-ignore[16] + return hasattr(struct, "camera_counts") and struct.camera_counts is not None def get_camera_wireframe(scale: float = 0.3): # pragma: no cover @@ -301,7 +333,7 @@ def plot_scene( _add_camera_trace( fig, struct, trace_name, 
subplot_idx, ncols, camera_scale ) - elif isinstance(struct, RayBundle): + elif _is_ray_bundle(struct): _add_ray_bundle_trace( fig, struct, @@ -316,7 +348,7 @@ def plot_scene( else: raise ValueError( "struct {} is not a Cameras, Meshes, Pointclouds,".format(struct) - + " or RayBundle object." + + " , RayBundle or HeterogeneousRayBundle object." ) # Ensure update for every subplot. @@ -401,15 +433,16 @@ def plot_batch_individually( In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in the input. These will either be rendered in the first subplot (if extend_struct is False), or in every subplot. + RayBundle includes ImplicitronRayBundle and HeterogeneousRaybundle. Args: - batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle - to be rendered. Each structure's corresponding batch element will be - plotted in a single subplot, resulting in n subplots for a batch of size n. - Every struct should either have the same batch size or be of batch size 1. - See extend_struct and the description above for how batch size 1 structs - are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle - object, which will have each individual element plotted in its own subplot. + batched_structs: a list of Cameras, Meshes, Pointclouds and RayBundle to be + rendered. Each structure's corresponding batch element will be plotted in a + single subplot, resulting in n subplots for a batch of size n. Every struct + should either have the same batch size or be of batch size 1. See extend_struct + and the description above for how batch size 1 structs are handled. Also accepts + a single Cameras, Meshes, Pointclouds, and RayBundle object, which will have + each individual element plotted in its own subplot. viewpoint_cameras: an instance of a Cameras object providing a location to view the plotly plot from. If the batch size is equal to the number of subplots, it is a one to one mapping. 
@@ -451,20 +484,20 @@ def plot_batch_individually( """ # check that every batch is the same size or is size 1 - if len(batched_structs) == 0: + if _get_len(batched_structs) == 0: msg = "No structs to plot" warnings.warn(msg) return max_size = 0 if isinstance(batched_structs, list): - max_size = max(_get_struct_len(s) for s in batched_structs) + max_size = max(_get_len(s) for s in batched_structs) for struct in batched_structs: - struct_len = _get_struct_len(struct) + struct_len = _get_len(struct) if struct_len not in (1, max_size): msg = "invalid batch size {} provided: {}".format(struct_len, struct) raise ValueError(msg) else: - max_size = _get_struct_len(batched_structs) + max_size = _get_len(batched_structs) if max_size == 0: msg = "No data is provided with at least one element" @@ -475,6 +508,14 @@ def plot_batch_individually( msg = "invalid number of subplot titles" raise ValueError(msg) + # if we are dealing with HeterogeneousRayBundle or ImplicitronRayBundle, create + # first indexes for faster slicing + first_idxs = None + if _is_heterogeneous_ray_bundle(batched_structs): + # pyre-ignore[16] + cumsum = batched_structs.camera_counts.cumsum(dim=0) + first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum)) + scene_dictionary = {} # construct the scene dictionary for scene_num in range(max_size): @@ -487,16 +528,30 @@ def plot_batch_individually( if isinstance(batched_structs, list): for i, batched_struct in enumerate(batched_structs): + first_idxs = None + if _is_heterogeneous_ray_bundle(batched_structs[i]): + # pyre-ignore[16] + cumsum = batched_struct.camera_counts.cumsum(dim=0) + first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum)) # check for whether this struct needs to be extended - batched_struct_len = _get_struct_len(batched_struct) + batched_struct_len = _get_len(batched_struct) if i >= batched_struct_len and not extend_struct: continue _add_struct_from_batch( - batched_struct, scene_num, subplot_title, scene_dictionary, i + 1 + batched_struct, 
scene_num, + subplot_title, + scene_dictionary, + i + 1, + first_idxs=first_idxs, ) else: # batched_structs is a single struct _add_struct_from_batch( - batched_structs, scene_num, subplot_title, scene_dictionary + batched_structs, + scene_num, + subplot_title, + scene_dictionary, + first_idxs=first_idxs, ) return plot_scene( @@ -510,6 +565,7 @@ def _add_struct_from_batch( subplot_title: str, scene_dictionary: Dict[str, Dict[str, Struct]], trace_idx: int = 1, + first_idxs: Optional[torch.Tensor] = None, ) -> None: # pragma: no cover """ Adds the struct corresponding to the given scene_num index to @@ -544,17 +600,35 @@ def _add_struct_from_batch( # torch.Tensor, torch.nn.Module]` is not a function. T = T[t_idx].unsqueeze(0) struct = CamerasBase(device=batched_struct.device, R=R, T=T) - elif isinstance(batched_struct, RayBundle): - # for RayBundle we treat the 1st dim as the batch index - struct_idx = min(scene_num, len(batched_struct.lengths) - 1) + elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle( + batched_struct + ): + # for RayBundle we treat the camera count as the batch index + struct_idx = min(scene_num, _get_len(batched_struct) - 1) + struct = RayBundle( **{ attr: getattr(batched_struct, attr)[struct_idx] for attr in ["origins", "directions", "lengths", "xys"] } ) + elif _is_heterogeneous_ray_bundle(batched_struct): + # for RayBundle we treat the camera count as the batch index + struct_idx = min(scene_num, _get_len(batched_struct) - 1) + + struct = RayBundle( + **{ + attr: getattr(batched_struct, attr)[ + # pyre-ignore[16] + first_idxs[struct_idx] : first_idxs[struct_idx + 1] + ] + for attr in ["origins", "directions", "lengths", "xys"] + } + ) + else: # batched meshes and pointclouds are indexable - struct_idx = min(scene_num, len(batched_struct) - 1) + struct_idx = min(scene_num, _get_len(batched_struct) - 1) + # pyre-ignore[16] struct = batched_struct[struct_idx] trace_name = "trace{}-{}".format(scene_num + 1, trace_idx) 
scene_dictionary[subplot_title][trace_name] = struct @@ -753,7 +827,7 @@ def _add_camera_trace( def _add_ray_bundle_trace( fig: go.Figure, - ray_bundle: RayBundle, + ray_bundle: Union[RayBundle, HeterogeneousRayBundle], trace_name: str, subplot_idx: int, ncols: int, @@ -763,12 +837,13 @@ def _add_ray_bundle_trace( line_width: int, ) -> None: # pragma: no cover """ - Adds a trace rendering a RayBundle object to the passed in figure, with - a given name and in a specific subplot. + Adds a trace rendering a ray bundle object + to the passed in figure, with a given name and in a specific subplot. Args: fig: plotly figure to add the trace within. - cameras: the Cameras object to render. It can be batched. + ray_bundle: the RayBundle, ImplicitronRayBundle or HeterogeneousRayBundle to render. + It can be batched. trace_name: name to label the trace with. subplot_idx: identifies the subplot, with 0 being the top left. ncols: the number of subplots per row. diff --git a/tests/implicitron/test_ray_point_refiner.py b/tests/implicitron/test_ray_point_refiner.py index fb512c24e..9373edc22 100644 --- a/tests/implicitron/test_ray_point_refiner.py +++ b/tests/implicitron/test_ray_point_refiner.py @@ -8,7 +8,7 @@ import torch from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner -from pytorch3d.renderer import RayBundle +from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle from tests.common_testing import TestCaseMixin @@ -24,7 +24,14 @@ def test_simple(self): add_input_samples=add_input_samples, ) lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length) - bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None) + bundle = ImplicitronRayBundle( + lengths=lengths, + origins=None, + directions=None, + xys=None, + camera_ids=None, + camera_counts=None, + ) weights = torch.ones(3, 25, length) refined = ray_point_refiner(bundle, weights) diff --git a/tests/implicitron/test_srn.py 
b/tests/implicitron/test_srn.py index f6905ef4d..311bbaa6d 100644 --- a/tests/implicitron/test_srn.py +++ b/tests/implicitron/test_srn.py @@ -13,8 +13,10 @@ SRNImplicitFunction, SRNPixelGenerator, ) +from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import get_default_args -from pytorch3d.renderer import PerspectiveCameras, RayBundle +from pytorch3d.renderer import PerspectiveCameras + from tests.common_testing import TestCaseMixin _BATCH_SIZE: int = 3 @@ -31,12 +33,17 @@ def setUp(self) -> None: def test_pixel_generator(self): SRNPixelGenerator() - def _get_bundle(self, *, device) -> RayBundle: + def _get_bundle(self, *, device) -> ImplicitronRayBundle: origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device) - bundle = RayBundle( - lengths=lengths, origins=origins, directions=directions, xys=None + bundle = ImplicitronRayBundle( + lengths=lengths, + origins=origins, + directions=directions, + xys=None, + camera_ids=None, + camera_counts=None, ) return bundle