Commit: Decoding functions
Summary: Added replaceable decoding functions, which are applied after the voxel grid to produce color and density.

Reviewed By: bottler

Differential Revision: D38829763

fbshipit-source-id: f21ce206c1c19548206ea2ce97d7ebea3de30a23
Darijan Gudelj authored and facebook-github-bot committed Aug 26, 2022
1 parent 24f5f4a commit e7c609f
Showing 3 changed files with 148 additions and 47 deletions.
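Not part of the commit: a minimal sketch of the extension point it introduces, assuming Implicitron's usual replaceable-member pattern (a member typed as the ReplaceableBase, a `<member>_class_type` field, and `run_auto_creation`). The `ToyModel` wrapper is hypothetical; `MLPDecoder` and `IdentityDecoder` are the registered subclasses from the diff below:

    import torch

    # importing the module registers MLPDecoder / IdentityDecoder
    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        DecoderFunctionBase,
    )
    from pytorch3d.implicitron.tools.config import (
        Configurable,
        expand_args_fields,
        run_auto_creation,
    )


    class ToyModel(Configurable, torch.nn.Module):  # hypothetical wrapper
        # Replaceable member: any registered DecoderFunctionBase subclass
        # ("MLPDecoder", "IdentityDecoder", ...) is selectable by name.
        decoder: DecoderFunctionBase
        decoder_class_type: str = "MLPDecoder"

        def __post_init__(self):
            super().__init__()
            run_auto_creation(self)  # builds self.decoder from the registry


    expand_args_fields(ToyModel)  # expand the Configurable into a dataclass
    model = ToyModel()            # swap decoders by changing decoder_class_type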
158 changes: 112 additions & 46 deletions pytorch3d/implicitron/models/implicit_function/decoding_functions.py
@@ -4,16 +4,66 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This file contains
    - modules used by ImplicitFunction objects to decode an embedding defined in
      space, e.g. into color or opacity.
    - DecoderFunctionBase and its subclasses, which wrap some of those modules and
      expose them as an extension point that an ImplicitFunction object can use.
"""

import logging

from typing import Optional, Tuple

import torch

from pytorch3d.implicitron.tools.config import (
Configurable,
registry,
ReplaceableBase,
run_auto_creation,
)

logger = logging.getLogger(__name__)


-class MLPWithInputSkips(torch.nn.Module):
+class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
"""
A decoding function is a torch.nn.Module which takes the embedding of a location in
space and transforms it into the required quantity (for example, density and color).
"""

def __post_init__(self):
super().__init__()

def forward(
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
features (torch.Tensor): tensor of shape (batch, ..., num_in_features)
z: optional tensor to append to parts of the decoding function
Returns:
decoded_features (torch.Tensor): tensor of
shape (batch, ..., num_out_features)
"""
raise NotImplementedError()


@registry.register
class IdentityDecoder(DecoderFunctionBase):
"""
Decoding function which returns its input.
"""

def forward(
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
) -> torch.Tensor:
return features


class MLPWithInputSkips(Configurable, torch.nn.Module):
"""
Implements the multi-layer perceptron architecture of the Neural Radiance Field.
@@ -31,70 +81,68 @@ class MLPWithInputSkips(torch.nn.Module):
and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
NeRF: Representing Scenes as Neural Radiance Fields for View
Synthesis, ECCV2020
Members:
n_layers: The number of linear layers of the MLP.
input_dim: The number of channels of the input tensor.
output_dim: The number of channels of the output.
skip_dim: The number of channels of the tensor `z` appended when
evaluating the skip layers.
hidden_dim: The number of hidden units of the MLP.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
"""

-    def _make_affine_layer(self, input_dim, hidden_dim):
-        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
-        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
-        _xavier_init(l1)
-        _xavier_init(l2)
-        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
+    n_layers: int = 8
+    input_dim: int = 39
+    output_dim: int = 256
+    skip_dim: int = 39
+    hidden_dim: int = 256
+    input_skips: Tuple[int, ...] = (5,)
+    skip_affine_trans: bool = False
+    no_last_relu: bool = False

-    def _apply_affine_layer(self, layer, x, z):
-        mu_log_std = layer(z)
-        mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
-        std = torch.nn.functional.softplus(log_std)
-        return (x - mu) * std
-
-    def __init__(
-        self,
-        n_layers: int = 8,
-        input_dim: int = 39,
-        output_dim: int = 256,
-        skip_dim: int = 39,
-        hidden_dim: int = 256,
-        input_skips: Tuple[int, ...] = (5,),
-        skip_affine_trans: bool = False,
-        no_last_relu=False,
-    ):
-        """
-        Args:
-            n_layers: The number of linear layers of the MLP.
-            input_dim: The number of channels of the input tensor.
-            output_dim: The number of channels of the output.
-            skip_dim: The number of channels of the tensor `z` appended when
-                evaluating the skip layers.
-            hidden_dim: The number of hidden units of the MLP.
-            input_skips: The list of layer indices at which we append the skip
-                tensor `z`.
-        """
+    def __post_init__(self):
super().__init__()
layers = []
skip_affine_layers = []
-        for layeri in range(n_layers):
-            dimin = hidden_dim if layeri > 0 else input_dim
-            dimout = hidden_dim if layeri + 1 < n_layers else output_dim
+        for layeri in range(self.n_layers):
+            dimin = self.hidden_dim if layeri > 0 else self.input_dim
+            dimout = self.hidden_dim if layeri + 1 < self.n_layers else self.output_dim

-            if layeri > 0 and layeri in input_skips:
-                if skip_affine_trans:
+            if layeri > 0 and layeri in self.input_skips:
+                if self.skip_affine_trans:
skip_affine_layers.append(
-                        self._make_affine_layer(skip_dim, hidden_dim)
+                        self._make_affine_layer(self.skip_dim, self.hidden_dim)
)
else:
-                    dimin = hidden_dim + skip_dim
+                    dimin = self.hidden_dim + self.skip_dim

linear = torch.nn.Linear(dimin, dimout)
_xavier_init(linear)
layers.append(
torch.nn.Sequential(linear, torch.nn.ReLU(True))
-                if not no_last_relu or layeri + 1 < n_layers
+                if not self.no_last_relu or layeri + 1 < self.n_layers
else linear
)
self.mlp = torch.nn.ModuleList(layers)
-        if skip_affine_trans:
+        if self.skip_affine_trans:
self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
-        self._input_skips = set(input_skips)
-        self._skip_affine_trans = skip_affine_trans
+        self._input_skips = set(self.input_skips)
+        self._skip_affine_trans = self.skip_affine_trans

def _make_affine_layer(self, input_dim, hidden_dim):
l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
_xavier_init(l1)
_xavier_init(l2)
return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)

def _apply_affine_layer(self, layer, x, z):
mu_log_std = layer(z)
mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
std = torch.nn.functional.softplus(log_std)
return (x - mu) * std

def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
"""
@@ -121,6 +169,24 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
return y


@registry.register
class MLPDecoder(DecoderFunctionBase):
"""
Decoding function which uses `MLPWithInputSkips` to convert the embedding to the output.
"""

network: MLPWithInputSkips

def __post_init__(self):
super().__post_init__()
run_auto_creation(self)

def forward(
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
) -> torch.Tensor:
return self.network(features, z)


class TransformerWithInputSkips(torch.nn.Module):
def __init__(
self,
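Not part of the commit: a quick usage sketch of the new `MLPDecoder`, assuming the default field values above (`input_dim=39`, `output_dim=256`) and passing the skip tensor `z` explicitly:

    import torch

    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        MLPDecoder,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(MLPDecoder)  # also expands the nested MLPWithInputSkips
    decoder = MLPDecoder()          # run_auto_creation builds decoder.network

    features = torch.randn(2, 1024, 39)  # (batch, ..., input_dim)
    out = decoder(features, z=features)  # -> (2, 1024, 256) == (batch, ..., output_dim)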
3 changes: 2 additions & 1 deletion pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py
@@ -9,7 +9,7 @@

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
-from pytorch3d.implicitron.tools.config import registry
+from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
@@ -214,6 +214,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
append_xyz: Tuple[int, ...] = (5,)

def _construct_xyz_encoder(self, input_dim: int):
+        expand_args_fields(MLPWithInputSkips)
return MLPWithInputSkips(
self.n_layers_xyz,
input_dim,
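The added `expand_args_fields(MLPWithInputSkips)` call is needed because `MLPWithInputSkips` is now a Configurable, and Configurable classes must be expanded into dataclasses before they can be constructed directly. A minimal sketch of direct construction (example values are mine, not from the commit):

    import torch

    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        MLPWithInputSkips,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(MLPWithInputSkips)  # required once before direct construction
    mlp = MLPWithInputSkips(n_layers=2, input_dim=39, output_dim=16)
    y = mlp(torch.randn(4, 39))  # -> (4, 16); default input_skips=(5,) never fires here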
34 changes: 34 additions & 0 deletions tests/implicitron/test_decoding_functions.py
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
IdentityDecoder,
MLPDecoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

from tests.common_testing import TestCaseMixin


class TestDecodingFunctions(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
expand_args_fields(IdentityDecoder)
expand_args_fields(MLPDecoder)

def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
"""
Test that the identity function returns its input.
"""
func = IdentityDecoder()
for _ in range(n_tests):
_in = torch.randn(in_shape)
assert torch.allclose(func(_in), _in)
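
A hypothetical companion test in the same style (not part of the commit; shapes assume MLPDecoder's defaults, input_dim=39 and output_dim=256):

    def test_mlp_decoder_shape(self, in_shape=(5, 7, 39)):
        """
        Hypothetical check: MLPDecoder maps (..., 39) to (..., 256) by default.
        """
        decoder = MLPDecoder()
        _in = torch.randn(in_shape)
        out = decoder(_in, z=_in)
        self.assertEqual(tuple(out.shape), (5, 7, 256))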
