This repository is an unofficial implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, using Flax and the Linen API.
B. Mildenhall, P.P. Srinivasan, M. Tancik, J.T. Barron, R. Ramamoorthi and R. Ng, NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, 2020, ECCV, arXiv:2003.08934 [cs.CV].
The original repository can be found at bmild/nerf.
Neural Radiance Fields (NeRF) is a method for synthesizing novel views of complex scenes, by optimizing an underlying continuous volumetric scene function using a sparse set of input views. Views are synthesized by querying 5D coordinates (spatial location (x, y, z) and viewing direction (θ, ϕ)) along camera rays and using classic volume rendering techniques to project the output colors and densities into an image.
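For intuition, the rendering step boils down to a numerical quadrature along each ray. The snippet below is a minimal illustrative sketch of that compositing step for a single ray; the function name and array shapes are assumptions for illustration and do not mirror this repository's API.

import jax.numpy as jnp

def composite_ray(rgb, sigma, z_vals):
    # rgb:    [num_samples, 3] predicted colors along one ray
    # sigma:  [num_samples]    predicted volume densities
    # z_vals: [num_samples]    sample depths along the ray
    dists = jnp.concatenate([z_vals[1:] - z_vals[:-1], jnp.array([1e10])])
    alpha = 1.0 - jnp.exp(-sigma * dists)             # opacity of each segment
    trans = jnp.cumprod(1.0 - alpha + 1e-10)          # accumulated transmittance
    trans = jnp.concatenate([jnp.array([1.0]), trans[:-1]])
    weights = alpha * trans                           # contribution of each sample
    return jnp.sum(weights[:, None] * rgb, axis=0)    # composited pixel color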
This implementation tries to stay as close as possible to the original source, adding some code optimizations and using the flexibility and native multi-device (GPU and TPU) support in JAX.
Most of the comments are from the original work, which are very helpful for understanding the model steps.
Install jax and jaxlib according to your platform configuration. Then, install the necessary dependencies with:
pip install --upgrade clu flax imageio imageio-ffmpeg ml_collections optax pandas tensorboard 'tensorflow>=2.4' tqdm
There are three subsets of data used in the original publication that can be downloaded from nerf_data:
- Blender from NeRF authors (nerf_synthetic.zip)
- DeepVoxels from Vincent Sitzmann (nerf_real_360.zip)
- LLFF from NeRF authors (nerf_llff_data.zip)
In addition, there is:
- nerf_example_data is limited to the lego (from Blender) and fern (from LLFF) scenes
- tiny_nerf_data is a low-resolution lego used in the simplified notebook example
Required parameters to run the training are:
- --data_dir: directory where data is placed
- --model_dir: model saving location
- --config: configuration parameters
python main.py \
--data_dir=/data/nerf_synthetic \
--model_dir=logs \
--config=configs/test_blender_lego.py
The configuration flag is defined using config_flags, which allows overriding configuration fields from the command line, as follows:
python main.py \
--data_dir=/data/nerf_synthetic \
--model_dir=logs \
--config=configs/test_blender_lego.py \
--config.num_samples=128 \
--config.i_print=250
NOTE: check and understand the effect of default parameters in configs/default.py
to avoid confusion when passing arguments to the model.
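For reference, an ml_collections config file follows a simple pattern; the snippet below is only an illustrative sketch with placeholder values, not a copy of configs/default.py:

import ml_collections

def get_config():
    config = ml_collections.ConfigDict()
    config.num_samples = 64   # overridable from the CLI with --config.num_samples
    config.i_print = 100      # console logging interval in steps
    return config

# The training script can register the flag with
# config_flags.DEFINE_config_file("config") from ml_collections, which enables
# both --config=<file> and --config.<field>=<value> overrides.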
All examples were run on an NVIDIA RTX 2080Ti. Examples prior to deterministic datasets are available in e81d608.
Commands
python main.py \
--data_dir=/data/nerf_synthetic \
--model_dir=logs_lego_64 \
--config=configs/test_blender_lego.py \
--config.batching=True \
--config.i_img=10000 \
--config.i_weights=10000
python render.py \
--data_dir=/data/nerf_synthetic \
--model_dir=logs_lego_64 \
--config=configs/test_blender_lego.py \
--config.render_factor=1 \
--config.testskip=0 \
--render_video_set=test
python main.py \
--data_dir=/data/nerf_synthetic \
--model_dir=logs_lego_128 \
--config=configs/paper_blender_lego.py \
--config.batching=True \
--config.i_img=10000 \
--config.i_weights=10000
python render.py \
--data_dir=/data/nerf_synthetic \
--model_dir=logs_lego_128 \
--config=configs/paper_blender_lego.py \
--config.render_factor=1 \
--config.testskip=0 \
--render_video_set=test
| Checkpoint path | Test set PSNR | Paper PSNR | TensorBoard.dev |
|---|---|---|---|
| lego_400_64 | 31.48 | - | 2021-01-15 |
| lego_800_128 | 32.29 | 32.54 | 2021-01-15 |
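For reference, the PSNR values above follow the standard definition derived from the mean squared error over test images; a minimal helper, not necessarily the exact function used in this repository:

import jax.numpy as jnp

def psnr(mse):
    # PSNR in dB for images scaled to [0, 1]: -10 * log10(MSE).
    return -10.0 * jnp.log(mse) / jnp.log(10.0)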
- You can test or debug multiple devices in a CPU-only installation using the XLA_FLAGS environment variable (more information in JAX #1408). To simulate 4 devices:
XLA_FLAGS="--xla_force_host_platform_device_count=4 xla_cpu_multi_thread_eigen=False intra_op_parallelism_threads=1"
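The same device-count flag can also be set from Python before importing jax, which can be convenient in notebooks (only the device-count flag is shown here):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
print(jax.device_count())  # 4 CPU devices, usable with pmap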
- Try to minimize time spent on rendering intermediate results during training (i_video, i_testset) and rely on validation results in TensorBoard. Either save intermediate checkpoints and render after training, or use render_factor and testskip to your advantage.
- Here are some recommendations for reducing GPU memory footprint:
  - Use the nn.remat decorator in your network module (more about jax.remat in JAX #1749); see the sketch after this list
  - Decrease model parameters (net_depth, net_width, num_importance, num_rand, num_samples)
  - Using bfloat16 will decrease memory usage by half, but the low precision reduces performance by a big margin
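A minimal sketch of wrapping a Flax module with nn.remat; the module and its fields below are illustrative placeholders, not this repository's model definition:

import flax.linen as nn

class MLP(nn.Module):
    net_depth: int = 8
    net_width: int = 256

    @nn.compact
    def __call__(self, x):
        for _ in range(self.net_depth):
            x = nn.relu(nn.Dense(self.net_width)(x))
        return nn.Dense(4)(x)

# Rematerialized variant: activations are recomputed during the backward pass
# instead of being stored, trading compute for memory.
RematMLP = nn.remat(MLP)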
- The original repository (bmild/nerf/issues) has many good comments and explanations from the authors and participants, which help to better understand the limitations and applications of this approach
- kwea123/nerf_pl is another implementation, using PyTorch Lightning, that has many explanations and applications for your trained models
- google/jaxnerf is a semi-official implementation of NeRF with JAX and Flax
- Training these models in Colab with TPUs is a bit of a stretch (FAQ - Resource Limits), although you can use it for rendering (an 800px square image takes ~26s on an NVIDIA RTX 2080Ti vs ~7s on a TPUv2). Add the following lines to the top of your file:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
- Rendering routines use lax.map, which is convenient for shaping outputs and fast at execution, although reshaping is a nuisance in some cases. Wait for mask redesign or rethink the execution.
- Most of the processing is done with batches of rays; rewrite everything for a single ray and vmap/pmap/xmap as needed (wait for the JAX unified map API, JAX#2939). A sketch of this pattern follows the list.
- Add function docs and lint
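A minimal sketch of that per-ray rewrite pattern; render_ray below is a stand-in, not the repository's renderer:

import jax
import jax.numpy as jnp

def render_ray(origin, direction):
    # Stand-in for a real single-ray rendering function.
    return jnp.tanh(origin + direction)

# Map the single-ray function over a batch of rays on one device...
render_rays = jax.vmap(render_ray)

# ...and shard batches across devices (the leading axis must match the device count).
render_rays_parallel = jax.pmap(render_rays)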