From e7c609f1980780e1a3df1525011425ffd5aa4e7a Mon Sep 17 00:00:00 2001
From: Darijan Gudelj
Date: Fri, 26 Aug 2022 08:47:30 -0700
Subject: [PATCH] Decoding functions

Summary:
Added replaceable decoding functions which will be applied after the voxel grid to get color and density.

Reviewed By: bottler

Differential Revision: D38829763

fbshipit-source-id: f21ce206c1c19548206ea2ce97d7ebea3de30a23
---
 .../implicit_function/decoding_functions.py   | 158 +++++++++++++-----
 .../neural_radiance_field.py                  |   3 +-
 tests/implicitron/test_decoding_functions.py  |  34 ++++
 3 files changed, 148 insertions(+), 47 deletions(-)
 create mode 100644 tests/implicitron/test_decoding_functions.py

diff --git a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
index 4222a56e0..71b9bd172 100644
--- a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
+++ b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
@@ -4,16 +4,66 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+"""
+This file contains
+    - modules which are used by ImplicitFunction objects to decode an embedding defined
+      in space, e.g. into color or opacity.
+    - DecoderFunctionBase and its subclasses, which wrap some of those modules and expose
+      them as an extension point which an ImplicitFunction object can use.
+"""
+
 import logging
 from typing import Optional, Tuple
 
 import torch
 
+from pytorch3d.implicitron.tools.config import (
+    Configurable,
+    registry,
+    ReplaceableBase,
+    run_auto_creation,
+)
+
 logger = logging.getLogger(__name__)
 
 
-class MLPWithInputSkips(torch.nn.Module):
+class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
+    """
+    A decoding function is a torch.nn.Module which takes the embedding of a location in
+    space and transforms it into the required quantity (for example density and color).
+    """
+
+    def __post_init__(self):
+        super().__init__()
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            features (torch.Tensor): tensor of shape (batch, ..., num_in_features)
+            z: optional tensor which is appended inside parts of the decoding function
+                (e.g. at skip layers)
+        Returns:
+            decoded_features (torch.Tensor): tensor of
+                shape (batch, ..., num_out_features)
+        """
+        raise NotImplementedError()
+
+
+@registry.register
+class IdentityDecoder(DecoderFunctionBase):
+    """
+    Decoding function which returns its input.
+    """
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        return features
+
+
+class MLPWithInputSkips(Configurable, torch.nn.Module):
     """
     Implements the multi-layer perceptron architecture of the Neural Radiance Field.
@@ -31,70 +81,68 @@ class MLPWithInputSkips(torch.nn.Module):
     and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
     NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, ECCV2020
+
+    Members:
+        n_layers: The number of linear layers of the MLP.
+        input_dim: The number of channels of the input tensor.
+        output_dim: The number of channels of the output.
+        skip_dim: The number of channels of the tensor `z` appended when
+            evaluating the skip layers.
+        hidden_dim: The number of hidden units of the MLP.
+        input_skips: The list of layer indices at which we append the skip
+            tensor `z`.
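+        skip_affine_trans: If True, instead of appending the skip tensor `z` to the
+            hidden activations, `z` is passed through an affine layer whose output
+            is used to scale and shift the activations.
+        no_last_relu: If True, the final linear layer is not followed by a ReLU.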
""" - def _make_affine_layer(self, input_dim, hidden_dim): - l1 = torch.nn.Linear(input_dim, hidden_dim * 2) - l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2) - _xavier_init(l1) - _xavier_init(l2) - return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2) + n_layers: int = 8 + input_dim: int = 39 + output_dim: int = 256 + skip_dim: int = 39 + hidden_dim: int = 256 + input_skips: Tuple[int, ...] = (5,) + skip_affine_trans: bool = False + no_last_relu = False - def _apply_affine_layer(self, layer, x, z): - mu_log_std = layer(z) - mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1) - std = torch.nn.functional.softplus(log_std) - return (x - mu) * std - - def __init__( - self, - n_layers: int = 8, - input_dim: int = 39, - output_dim: int = 256, - skip_dim: int = 39, - hidden_dim: int = 256, - input_skips: Tuple[int, ...] = (5,), - skip_affine_trans: bool = False, - no_last_relu=False, - ): - """ - Args: - n_layers: The number of linear layers of the MLP. - input_dim: The number of channels of the input tensor. - output_dim: The number of channels of the output. - skip_dim: The number of channels of the tensor `z` appended when - evaluating the skip layers. - hidden_dim: The number of hidden units of the MLP. - input_skips: The list of layer indices at which we append the skip - tensor `z`. - """ + def __post_init__(self): super().__init__() layers = [] skip_affine_layers = [] - for layeri in range(n_layers): - dimin = hidden_dim if layeri > 0 else input_dim - dimout = hidden_dim if layeri + 1 < n_layers else output_dim + for layeri in range(self.n_layers): + dimin = self.hidden_dim if layeri > 0 else self.input_dim + dimout = self.hidden_dim if layeri + 1 < self.n_layers else self.output_dim - if layeri > 0 and layeri in input_skips: - if skip_affine_trans: + if layeri > 0 and layeri in self.input_skips: + if self.skip_affine_trans: skip_affine_layers.append( - self._make_affine_layer(skip_dim, hidden_dim) + self._make_affine_layer(self.skip_dim, self.hidden_dim) ) else: - dimin = hidden_dim + skip_dim + dimin = self.hidden_dim + self.skip_dim linear = torch.nn.Linear(dimin, dimout) _xavier_init(linear) layers.append( torch.nn.Sequential(linear, torch.nn.ReLU(True)) - if not no_last_relu or layeri + 1 < n_layers + if not self.no_last_relu or layeri + 1 < self.n_layers else linear ) self.mlp = torch.nn.ModuleList(layers) - if skip_affine_trans: + if self.skip_affine_trans: self.skip_affines = torch.nn.ModuleList(skip_affine_layers) - self._input_skips = set(input_skips) - self._skip_affine_trans = skip_affine_trans + self._input_skips = set(self.input_skips) + self._skip_affine_trans = self.skip_affine_trans + + def _make_affine_layer(self, input_dim, hidden_dim): + l1 = torch.nn.Linear(input_dim, hidden_dim * 2) + l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2) + _xavier_init(l1) + _xavier_init(l2) + return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2) + + def _apply_affine_layer(self, layer, x, z): + mu_log_std = layer(z) + mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1) + std = torch.nn.functional.softplus(log_std) + return (x - mu) * std def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None): """ @@ -121,6 +169,24 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None): return y +@registry.register +class MLPDecoder(DecoderFunctionBase): + """ + Decoding function which uses `MLPWithIputSkips` to convert the embedding to output. 
+    """
+
+    network: MLPWithInputSkips
+
+    def __post_init__(self):
+        super().__post_init__()
+        run_auto_creation(self)
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        return self.network(features, z)
+
+
 class TransformerWithInputSkips(torch.nn.Module):
     def __init__(
         self,
diff --git a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py
index 285015877..78a3c8a44 100644
--- a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py
+++ b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py
@@ -9,7 +9,7 @@
 import torch
 from pytorch3d.common.linear_with_repeat import LinearWithRepeat
-from pytorch3d.implicitron.tools.config import registry
+from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.renderer.implicit import HarmonicEmbedding
@@ -214,6 +214,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
     append_xyz: Tuple[int, ...] = (5,)
 
     def _construct_xyz_encoder(self, input_dim: int):
+        expand_args_fields(MLPWithInputSkips)
         return MLPWithInputSkips(
             self.n_layers_xyz,
             input_dim,
diff --git a/tests/implicitron/test_decoding_functions.py b/tests/implicitron/test_decoding_functions.py
new file mode 100644
index 000000000..0a8db59a1
--- /dev/null
+++ b/tests/implicitron/test_decoding_functions.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import unittest
+
+import torch
+
+from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
+    IdentityDecoder,
+    MLPDecoder,
+)
+from pytorch3d.implicitron.tools.config import expand_args_fields
+
+from tests.common_testing import TestCaseMixin
+
+
+class TestDecodingFunctions(TestCaseMixin, unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+        expand_args_fields(IdentityDecoder)
+        expand_args_fields(MLPDecoder)
+
+    def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
+        """
+        Test that the identity decoder returns its input.
+        """
+        func = IdentityDecoder()
+        for _ in range(n_tests):
+            _in = torch.randn(in_shape)
+            self.assertClose(func(_in), _in)
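
Usage note (not part of the patch): the new unit test only exercises IdentityDecoder. The sketch below shows one way MLPDecoder could be driven through the same expand_args_fields machinery, using only names and default field values that appear in the diff (input_dim=39, skip_dim=39, output_dim=256). The call pattern, constructing MLPDecoder() with defaults and passing the features again as the skip tensor `z`, is an assumption, not something the patch itself demonstrates.

    # Hypothetical usage sketch; import paths and defaults come from the patch,
    # the construction/call pattern is assumed from the test above.
    import torch

    from pytorch3d.implicitron.models.implicit_function.decoding_functions import MLPDecoder
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(MLPDecoder)   # make the Configurable class constructible, as in setUp
    decoder = MLPDecoder()           # __post_init__ runs run_auto_creation to build `network`

    features = torch.randn(2, 1024, 39)      # (batch, ..., num_in_features), default input_dim=39
    decoded = decoder(features, z=features)  # skip tensor channels match default skip_dim=39
    print(decoded.shape)                     # expected (2, 1024, 256) with the default output_dim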