Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docs #34

Merged
merged 12 commits into from
Sep 28, 2022
46 changes: 6 additions & 40 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,24 @@

This is a **tiny** tootlbox for **accelerating** NeRF training & rendering using PyTorch CUDA extensions. Plug-and-play for most of the NeRFs!

## Examples: Instant-NGP NeRF
## Examples:

``` bash
# Instant-NGP NeRF
python examples/train_ngp_nerf.py --train_split trainval --scene lego
```

Performance:

| PSNR | Lego | Mic | Materials | Chair | Hotdog |
| - | - | - | - | - | - |
| Papers (5mins) | 36.39 | 36.22 | 29.78 | 35.00 | 37.40 |
| Ours (~5mins) | 36.61 | 37.62 | 30.11 | 36.09 | 38.09 |
| Exact training time | 300s | 274s | 266s | 341s | 277s |


## Examples: Vanilla MLP NeRF

``` bash
# Vanilla MLP NeRF
python examples/train_mlp_nerf.py --train_split train --scene lego
```

Performance:

| PNSR | Lego | Mic | Materials | Chair | Hotdog |
| - | - | - | - | - | - |
| Paper (~2days) | 32.54 | 32.91 | 29.62 | 33.00 | 36.18 |
| Ours (~45mins) | 33.21 | 33.36 | 29.48 | 32.79 | 35.54 |

## Examples: MLP NeRF on Dynamic objects (D-NeRF)

```bash
# MLP NeRF on Dynamic objects (D-NeRF)
python examples/train_mlp_dnerf.py --train_split train --scene lego
```

Performance:

| | Lego | Stand Up |
| - | - | - |
| Paper (~2days) | 21.64 | 32.79 |
| Ours (~45mins) | 24.66 | 33.98 |


## Examples: NGP on unbounded scene

On MipNeRF360 Garden scene

```bash
python examples/train_ngp_nerf.py --train_split train --scene garden --aabb="-4,-4,-4,4,4,4" --unbounded --cone_angle=0.004
# NGP on MipNeRF360 unbounded scene
python examples/train_ngp_nerf.py --train_split train --scene garden --auto_aabb --unbounded --cone_angle=0.004
```

Performance:

| | Garden |
| - | - |
| Ours | 25.13 |
15 changes: 14 additions & 1 deletion docs/source/examples/dnerf.rst
Original file line number Diff line number Diff line change
@@ -1,2 +1,15 @@
Dynamic Scene
====================
====================



+----------------------+----------+---------+-------+---------+-------+--------+---------+-------+
| | bouncing | hell | hook | jumping | lego | mutant | standup | trex |
| | balls | warrior | | jacks | | | | |
+======================+==========+=========+=======+=========+=======+========+=========+=======+
| Paper (PSNR: ~2day) | 38.93 | 25.02 | 29.25 | 32.80 | 21.64 | 31.29 | 32.79 | 31.75 |
+----------------------+----------+---------+-------+---------+-------+--------+---------+-------+
| Ours (PSNR: ~50min) | 39.60 | 22.41 | 30.64 | 29.79 | 24.75 | 35.20 | 34.50 | 31.83 |
+----------------------+----------+---------+-------+---------+-------+--------+---------+-------+
| Ours (Training time)| 45min | 49min | 51min | 46min | 53min | 57min | 49min | 46min |
+----------------------+----------+---------+-------+---------+-------+--------+---------+-------+
14 changes: 13 additions & 1 deletion docs/source/examples/unbounded.rst
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
Unbounded Scene
====================
====================


+----------------------+-------+
| | Garden|
| | |
+======================+=======+
| Paper (PSNR: ~ days) | 26.98 |
+----------------------+-------+
| Ours (PSNR: ~ 1 hr) | 25.41 |
+----------------------+-------+
| Ours (Training time)| 58min |
+----------------------+-------+
6 changes: 3 additions & 3 deletions docs/source/examples/vanilla.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ Benchmarks
| | Lego | Mic | Materials |Chair |Hotdog | Ficus | Drums | Ship |
| | | | | | | | | |
+======================+=======+=======+============+=======+========+========+========+========+
| Paper (PSNR: 5min) | 32.54 | 32.91 | 29.62 | 33.00 | 36.18 | 30.13 | 25.01 | 28.65 |
| Paper (PSNR: 1~2days)| 32.54 | 32.91 | 29.62 | 33.00 | 36.18 | 30.13 | 25.01 | 28.65 |
+----------------------+-------+-------+------------+-------+--------+--------+--------+--------+
| Ours (PSNR) | XX.XX | XX.XX | XX.XX | XX.XX | XX.XX | XX.XX | XX.XX | XX.XX |
| Ours (PSNR: ~50min) | 33.69 | 33.76 | 29.73 | 33.32 | 35.80 | 32.52 | 25.39 | 28.18 |
+----------------------+-------+-------+------------+-------+--------+--------+--------+--------+
| Ours (Training time)| XXmin | XXmin | XXmin | 45min | XXmin | XXmin | 41min | XXmin |
| Ours (Training time)| 58min | 53min | 46min | 62min | 56min | 42min | 52min | 49min |
+----------------------+-------+-------+------------+-------+--------+--------+--------+--------+

.. _`github repository`: : https://github.com/KAIR-BAIR/nerfacc/
6 changes: 6 additions & 0 deletions examples/datasets/nerf_360_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ class SubjectLoader(torch.utils.data.Dataset):
SPLITS = ["train", "test"]
SUBJECT_IDS = [
"garden",
"bicycle",
"bonsai",
"counter",
"kitchen",
"room",
"stump",
]

OPENGL_CAMERA = False
Expand Down
50 changes: 31 additions & 19 deletions examples/radiance_fields/ngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,31 @@ def backward(ctx, g): # pylint: disable=arguments-differ
trunc_exp = _TruncExp.apply


def contract_to_unisphere(
x: torch.Tensor,
aabb: torch.Tensor,
eps: float = 1e-6,
derivative: bool = False,
):
aabb_min, aabb_max = torch.split(aabb, 3, dim=-1)
x = (x - aabb_min) / (aabb_max - aabb_min)
x = x * 2 - 1 # aabb is at [-1, 1]
mag = x.norm(dim=-1, keepdim=True)
mask = mag.squeeze(-1) > 1

if derivative:
dev = (2 * mag - 1) / mag**2 + 2 * x**2 * (
1 / mag**3 - (2 * mag - 1) / mag**4
)
dev[~mask] = 1.0
dev = torch.clamp(dev, min=eps)
return dev
else:
x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
x = x / 4 + 0.5 # [-inf, inf] is at [0, 1]
return x


class NGPradianceField(torch.nn.Module):
"""Instance-NGP radiance Field"""

Expand Down Expand Up @@ -114,31 +139,18 @@ def __init__(

def query_opacity(self, x, step_size):
density = self.query_density(x)
aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
if self.unbounded:
# TODO: [revisit] is this necessary?
# 1.0 / derivative of tanh contraction
x = (x - aabb_min) / (aabb_max - aabb_min)
x = x - 0.5
scaling = 1.0 / (
torch.clamp(1.0 - torch.tanh(x) ** 2, min=1e6) * 0.5
)
scaling = scaling * (aabb_max - aabb_min)
else:
scaling = aabb_max - aabb_min
step_size = step_size * scaling.norm(dim=-1, keepdim=True)
# if the density is small enough those two are the same.
# opacity = 1.0 - torch.exp(-density * step_size)
# NOTE: In principle, we should use the following formula to scale
# up the step size, but in practice, it is somehow not helpful.
# derivitive = contract_to_unisphere(x, self.aabb, derivative=True)
# step_size = step_size / derivitive.norm(dim=-1, keepdim=True)
pass
opacity = density * step_size
return opacity

def query_density(self, x, return_feat: bool = False):
if self.unbounded:
# tanh contraction
aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
x = (x - aabb_min) / (aabb_max - aabb_min)
x = x - 0.5
x = (torch.tanh(x) + 1) * 0.5
x = contract_to_unisphere(x, self.aabb)
else:
aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
x = (x - aabb_min) / (aabb_max - aabb_min)
Expand Down
98 changes: 59 additions & 39 deletions examples/train_ngp_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@
"ship",
# mipnerf360 unbounded
"garden",
"bicycle",
"bonsai",
"counter",
"kitchen",
"room",
"stump",
],
help="which scene to use",
)
Expand All @@ -61,58 +67,27 @@
action="store_true",
help="whether to use unbounded rendering",
)
parser.add_argument(
"--auto_aabb",
action="store_true",
help="whether to automatically compute the aabb",
)
parser.add_argument("--cone_angle", type=float, default=0.0)
args = parser.parse_args()

render_n_samples = 1024

# setup the scene bounding box.
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
render_step_size = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()

# setup the radiance field we want to train.
max_steps = 20000
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPradianceField(
aabb=args.aabb,
unbounded=args.unbounded,
).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
gamma=0.33,
)

# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.scene == "garden":
if args.unbounded:
from datasets.nerf_360_v2 import SubjectLoader

data_root_fp = "/home/ruilongli/data/360_v2/"
target_sample_batch_size = 1 << 20
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
grid_resolution = 128
grid_resolution = 256
else:
from datasets.nerf_synthetic import SubjectLoader

Expand Down Expand Up @@ -143,6 +118,51 @@
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
test_dataset.K = test_dataset.K.to(device)

if args.auto_aabb:
camera_locs = torch.cat(
[train_dataset.camtoworlds, test_dataset.camtoworlds]
)[:, :3, -1]
args.aabb = torch.cat(
[camera_locs.min(dim=0).values, camera_locs.max(dim=0).values]
).tolist()
print("Using auto aabb", args.aabb)

# setup the scene bounding box.
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
render_step_size = 1e-2
else:
contraction_type = ContractionType.AABB
scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
near_plane = None
far_plane = None
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max()
* math.sqrt(3)
/ render_n_samples
).item()

# setup the radiance field we want to train.
max_steps = 40000 if args.unbounded else 20000
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPradianceField(
aabb=args.aabb,
unbounded=args.unbounded,
).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
gamma=0.33,
)

occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb,
resolution=grid_resolution,
Expand Down Expand Up @@ -201,7 +221,7 @@
optimizer.step()
scheduler.step()

if step % 5000 == 0:
if step % 10000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print(
Expand Down
5 changes: 5 additions & 0 deletions nerfacc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def _update(
x = (
grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
) / self.resolution
if self._contraction_type == ContractionType.UN_BOUNDED_SPHERE:
# only the points inside the sphere are valid
mask = (x - 0.5).norm(dim=1) < 0.5
x = x[mask]
indices = indices[mask]
# voxel coordinates [0, 1]^3 -> world
x = contract_inv(
x,
Expand Down