diff --git a/README.md b/README.md
index a8eb4a27..b9998568 100644
--- a/README.md
+++ b/README.md
@@ -4,58 +4,24 @@
 
 This is a **tiny** toolbox for **accelerating** NeRF training & rendering using PyTorch CUDA extensions. Plug-and-play for most of the NeRFs!
 
-## Examples: Instant-NGP NeRF
+## Examples:
 
 ``` bash
+# Instant-NGP NeRF
 python examples/train_ngp_nerf.py --train_split trainval --scene lego
 ```
 
-Performance:
-
-| PSNR | Lego | Mic | Materials | Chair | Hotdog |
-| - | - | - | - | - | - |
-| Papers (5mins) | 36.39 | 36.22 | 29.78 | 35.00 | 37.40 |
-| Ours (~5mins) | 36.61 | 37.62 | 30.11 | 36.09 | 38.09 |
-| Exact training time | 300s | 274s | 266s | 341s | 277s |
-
-
-## Examples: Vanilla MLP NeRF
-
 ``` bash
+# Vanilla MLP NeRF
 python examples/train_mlp_nerf.py --train_split train --scene lego
 ```
 
-Performance:
-
-| PSNR | Lego | Mic | Materials | Chair | Hotdog |
-| - | - | - | - | - | - |
-| Paper (~2days) | 32.54 | 32.91 | 29.62 | 33.00 | 36.18 |
-| Ours (~45mins) | 33.21 | 33.36 | 29.48 | 32.79 | 35.54 |
-
-## Examples: MLP NeRF on Dynamic objects (D-NeRF)
-
 ```bash
+# MLP NeRF on Dynamic objects (D-NeRF)
 python examples/train_mlp_dnerf.py --train_split train --scene lego
 ```
 
-Performance:
-
-| | Lego | Stand Up |
-| - | - | - |
-| Paper (~2days) | 21.64 | 32.79 |
-| Ours (~45mins) | 24.66 | 33.98 |
-
-
-## Examples: NGP on unbounded scene
-
-On MipNeRF360 Garden scene
-
 ```bash
-python examples/train_ngp_nerf.py --train_split train --scene garden --aabb="-4,-4,-4,4,4,4" --unbounded --cone_angle=0.004
+# NGP on MipNeRF360 unbounded scene
+python examples/train_ngp_nerf.py --train_split train --scene garden --auto_aabb --unbounded --cone_angle=0.004
 ```
-
-Performance:
-
-| | Garden |
-| - | - |
-| Ours | 25.13 |
diff --git a/docs/source/examples/dnerf.rst b/docs/source/examples/dnerf.rst
index 304fc417..9a5b01ee 100644
--- a/docs/source/examples/dnerf.rst
+++ b/docs/source/examples/dnerf.rst
@@ -1,2 +1,15 @@
 Dynamic Scene
-====================
\ No newline at end of file
+====================
+
+
+
++----------------------+----------+---------+-------+---------+-------+--------+---------+-------+
+|                      | bouncing | hell    | hook  | jumping | lego  | mutant | standup | trex  |
+|                      | balls    | warrior |       | jacks   |       |        |         |       |
++======================+==========+=========+=======+=========+=======+========+=========+=======+
+| Paper (PSNR: ~2day)  | 38.93    | 25.02   | 29.25 | 32.80   | 21.64 | 31.29  | 32.79   | 31.75 |
++----------------------+----------+---------+-------+---------+-------+--------+---------+-------+
+| Ours (PSNR: ~50min)  | 39.60    | 22.41   | 30.64 | 29.79   | 24.75 | 35.20  | 34.50   | 31.83 |
++----------------------+----------+---------+-------+---------+-------+--------+---------+-------+
+| Ours (Training time) | 45min    | 49min   | 51min | 46min   | 53min | 57min  | 49min   | 46min |
++----------------------+----------+---------+-------+---------+-------+--------+---------+-------+
diff --git a/docs/source/examples/unbounded.rst b/docs/source/examples/unbounded.rst
index 765af557..50f3cc31 100644
--- a/docs/source/examples/unbounded.rst
+++ b/docs/source/examples/unbounded.rst
@@ -1,2 +1,14 @@
 Unbounded Scene
-====================
\ No newline at end of file
+====================
+
+
++----------------------+-------+
+|                      | Garden|
+|                      |       |
++======================+=======+
+| Paper (PSNR: ~ days) | 26.98 |
++----------------------+-------+
+| Ours (PSNR: ~ 1 hr)  | 25.41 |
++----------------------+-------+
+| Ours (Training time) | 58min |
++----------------------+-------+
diff --git a/docs/source/examples/vanilla.rst b/docs/source/examples/vanilla.rst
index f4d65bf8..fdd55c4c 100644
--- a/docs/source/examples/vanilla.rst
+++ b/docs/source/examples/vanilla.rst
@@ -11,11 +11,11 @@ Benchmarks
 |                      | Lego  | Mic   | Materials  |Chair  |Hotdog  | Ficus  | Drums  | Ship   |
 |                      |       |       |            |       |        |        |        |        |
 +======================+=======+=======+============+=======+========+========+========+========+
-| Paper (PSNR: 5min)   | 32.54 | 32.91 | 29.62      | 33.00 | 36.18  | 30.13  | 25.01  | 28.65  |
+| Paper (PSNR: 1~2days)| 32.54 | 32.91 | 29.62      | 33.00 | 36.18  | 30.13  | 25.01  | 28.65  |
 +----------------------+-------+-------+------------+-------+--------+--------+--------+--------+
-| Ours (PSNR)          | XX.XX | XX.XX | XX.XX      | XX.XX | XX.XX  | XX.XX  | XX.XX  | XX.XX  |
+| Ours (PSNR: ~50min)  | 33.69 | 33.76 | 29.73      | 33.32 | 35.80  | 32.52  | 25.39  | 28.18  |
 +----------------------+-------+-------+------------+-------+--------+--------+--------+--------+
-| Ours (Training time) | XXmin | XXmin | XXmin      | 45min | XXmin  | XXmin  | 41min  | XXmin  |
+| Ours (Training time) | 58min | 53min | 46min      | 62min | 56min  | 42min  | 52min  | 49min  |
 +----------------------+-------+-------+------------+-------+--------+--------+--------+--------+
 
 .. _`github repository`: : https://github.com/KAIR-BAIR/nerfacc/
\ No newline at end of file
diff --git a/examples/datasets/nerf_360_v2.py b/examples/datasets/nerf_360_v2.py
index fc0766ca..cca2b7b8 100644
--- a/examples/datasets/nerf_360_v2.py
+++ b/examples/datasets/nerf_360_v2.py
@@ -144,6 +144,12 @@ class SubjectLoader(torch.utils.data.Dataset):
     SPLITS = ["train", "test"]
     SUBJECT_IDS = [
         "garden",
+        "bicycle",
+        "bonsai",
+        "counter",
+        "kitchen",
+        "room",
+        "stump",
     ]
 
     OPENGL_CAMERA = False
diff --git a/examples/radiance_fields/ngp.py b/examples/radiance_fields/ngp.py
index c60eb146..240abf9b 100644
--- a/examples/radiance_fields/ngp.py
+++ b/examples/radiance_fields/ngp.py
@@ -34,6 +34,31 @@ def backward(ctx, g):  # pylint: disable=arguments-differ
 
 trunc_exp = _TruncExp.apply
 
 
+def contract_to_unisphere(
+    x: torch.Tensor,
+    aabb: torch.Tensor,
+    eps: float = 1e-6,
+    derivative: bool = False,
+):
+    aabb_min, aabb_max = torch.split(aabb, 3, dim=-1)
+    x = (x - aabb_min) / (aabb_max - aabb_min)
+    x = x * 2 - 1  # aabb is at [-1, 1]
+    mag = x.norm(dim=-1, keepdim=True)
+    mask = mag.squeeze(-1) > 1
+
+    if derivative:
+        dev = (2 * mag - 1) / mag**2 + 2 * x**2 * (
+            1 / mag**3 - (2 * mag - 1) / mag**4
+        )
+        dev[~mask] = 1.0
+        dev = torch.clamp(dev, min=eps)
+        return dev
+    else:
+        x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
+        x = x / 4 + 0.5  # [-inf, inf] is at [0, 1]
+        return x
+
+
 class NGPradianceField(torch.nn.Module):
     """Instance-NGP radiance Field"""
 
@@ -114,31 +139,18 @@ def __init__(
 
     def query_opacity(self, x, step_size):
         density = self.query_density(x)
-        aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
         if self.unbounded:
-            # TODO: [revisit] is this necessary?
-            # 1.0 / derivative of tanh contraction
-            x = (x - aabb_min) / (aabb_max - aabb_min)
-            x = x - 0.5
-            scaling = 1.0 / (
-                torch.clamp(1.0 - torch.tanh(x) ** 2, min=1e6) * 0.5
-            )
-            scaling = scaling * (aabb_max - aabb_min)
-        else:
-            scaling = aabb_max - aabb_min
-        step_size = step_size * scaling.norm(dim=-1, keepdim=True)
-        # if the density is small enough those two are the same.
-        # opacity = 1.0 - torch.exp(-density * step_size)
+            # NOTE: In principle, we should use the following formula to scale
+            # up the step size, but in practice, it is somehow not helpful.
+            # derivative = contract_to_unisphere(x, self.aabb, derivative=True)
+            # step_size = step_size / derivative.norm(dim=-1, keepdim=True)
+            pass
         opacity = density * step_size
         return opacity
 
     def query_density(self, x, return_feat: bool = False):
         if self.unbounded:
-            # tanh contraction
-            aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
-            x = (x - aabb_min) / (aabb_max - aabb_min)
-            x = x - 0.5
-            x = (torch.tanh(x) + 1) * 0.5
+            x = contract_to_unisphere(x, self.aabb)
         else:
             aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
             x = (x - aabb_min) / (aabb_max - aabb_min)
diff --git a/examples/train_ngp_nerf.py b/examples/train_ngp_nerf.py
index f8a90b75..f2802198 100644
--- a/examples/train_ngp_nerf.py
+++ b/examples/train_ngp_nerf.py
@@ -42,6 +42,12 @@
             "ship",
             # mipnerf360 unbounded
             "garden",
+            "bicycle",
+            "bonsai",
+            "counter",
+            "kitchen",
+            "room",
+            "stump",
         ],
         help="which scene to use",
     )
@@ -61,58 +67,27 @@
         action="store_true",
         help="whether to use unbounded rendering",
     )
+    parser.add_argument(
+        "--auto_aabb",
+        action="store_true",
+        help="whether to automatically compute the aabb",
+    )
     parser.add_argument("--cone_angle", type=float, default=0.0)
     args = parser.parse_args()
 
     render_n_samples = 1024
 
-    # setup the scene bounding box.
-    if args.unbounded:
-        print("Using unbounded rendering")
-        contraction_type = ContractionType.UN_BOUNDED_SPHERE
-        # contraction_type = ContractionType.UN_BOUNDED_TANH
-        scene_aabb = None
-        near_plane = 0.2
-        far_plane = 1e4
-        render_step_size = 1e-2
-    else:
-        contraction_type = ContractionType.AABB
-        scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
-        near_plane = None
-        far_plane = None
-        render_step_size = (
-            (scene_aabb[3:] - scene_aabb[:3]).max()
-            * math.sqrt(3)
-            / render_n_samples
-        ).item()
-
-    # setup the radiance field we want to train.
-    max_steps = 20000
-    grad_scaler = torch.cuda.amp.GradScaler(2**10)
-    radiance_field = NGPradianceField(
-        aabb=args.aabb,
-        unbounded=args.unbounded,
-    ).to(device)
-    optimizer = torch.optim.Adam(
-        radiance_field.parameters(), lr=1e-2, eps=1e-15
-    )
-    scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer,
-        milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
-        gamma=0.33,
-    )
-
     # setup the dataset
     train_dataset_kwargs = {}
     test_dataset_kwargs = {}
-    if args.scene == "garden":
+    if args.unbounded:
         from datasets.nerf_360_v2 import SubjectLoader
 
         data_root_fp = "/home/ruilongli/data/360_v2/"
         target_sample_batch_size = 1 << 20
         train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
         test_dataset_kwargs = {"factor": 4}
-        grid_resolution = 128
+        grid_resolution = 256
     else:
         from datasets.nerf_synthetic import SubjectLoader
@@ -143,6 +118,51 @@
     test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
     test_dataset.K = test_dataset.K.to(device)
 
+    if args.auto_aabb:
+        camera_locs = torch.cat(
+            [train_dataset.camtoworlds, test_dataset.camtoworlds]
+        )[:, :3, -1]
+        args.aabb = torch.cat(
+            [camera_locs.min(dim=0).values, camera_locs.max(dim=0).values]
+        ).tolist()
+        print("Using auto aabb", args.aabb)
+
+    # setup the scene bounding box.
+    if args.unbounded:
+        print("Using unbounded rendering")
+        contraction_type = ContractionType.UN_BOUNDED_SPHERE
+        # contraction_type = ContractionType.UN_BOUNDED_TANH
+        scene_aabb = None
+        near_plane = 0.2
+        far_plane = 1e4
+        render_step_size = 1e-2
+    else:
+        contraction_type = ContractionType.AABB
+        scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
+        near_plane = None
+        far_plane = None
+        render_step_size = (
+            (scene_aabb[3:] - scene_aabb[:3]).max()
+            * math.sqrt(3)
+            / render_n_samples
+        ).item()
+
+    # setup the radiance field we want to train.
+    max_steps = 40000 if args.unbounded else 20000
+    grad_scaler = torch.cuda.amp.GradScaler(2**10)
+    radiance_field = NGPradianceField(
+        aabb=args.aabb,
+        unbounded=args.unbounded,
+    ).to(device)
+    optimizer = torch.optim.Adam(
+        radiance_field.parameters(), lr=1e-2, eps=1e-15
+    )
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer,
+        milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
+        gamma=0.33,
+    )
+
     occupancy_grid = OccupancyGrid(
         roi_aabb=args.aabb,
         resolution=grid_resolution,
@@ -201,7 +221,7 @@
             optimizer.step()
             scheduler.step()
 
-            if step % 5000 == 0:
+            if step % 10000 == 0:
                 elapsed_time = time.time() - tic
                 loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
                 print(
diff --git a/nerfacc/grid.py b/nerfacc/grid.py
index 3f4cc0a7..33b862b0 100644
--- a/nerfacc/grid.py
+++ b/nerfacc/grid.py
@@ -177,6 +177,11 @@ def _update(
         x = (
             grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
         ) / self.resolution
+        if self._contraction_type == ContractionType.UN_BOUNDED_SPHERE:
+            # only the points inside the sphere are valid
+            mask = (x - 0.5).norm(dim=1) < 0.5
+            x = x[mask]
+            indices = indices[mask]
         # voxel coordinates [0, 1]^3 -> world
         x = contract_inv(
             x,
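
For context, a small illustrative sketch (not part of the patch above) of how the `contract_to_unisphere` function introduced in `examples/radiance_fields/ngp.py` maps world-space points into the `[0, 1]` cube for unbounded scenes. The import path and the example `aabb` values are assumptions for demonstration only; in the patch the function is called internally by `NGPradianceField`.

```python
import torch

# Assumes contract_to_unisphere from examples/radiance_fields/ngp.py (added by
# this patch) is importable, e.g. when running from the examples/ directory.
from radiance_fields.ngp import contract_to_unisphere

# Hypothetical scene box, e.g. what --auto_aabb might derive from camera origins.
aabb = torch.tensor([-4.0, -4.0, -4.0, 4.0, 4.0, 4.0])

pts = torch.tensor(
    [
        [0.0, 0.0, 0.0],    # scene center
        [4.0, 0.0, 0.0],    # on the aabb boundary
        [400.0, 0.0, 0.0],  # far outside the aabb
    ]
)

print(contract_to_unisphere(pts, aabb))
# -> roughly [[0.5000, 0.5, 0.5], [0.7500, 0.5, 0.5], [0.9975, 0.5, 0.5]]
# Points inside the sphere inscribed in the aabb land within radius 0.25 of 0.5;
# points toward infinity approach the cube boundary, so the hash grid and the
# UN_BOUNDED_SPHERE occupancy grid always see samples in a bounded domain.
```

The `derivative=True` branch returns the per-axis derivative of this mapping, which is what the (currently commented-out) step-size rescaling in `query_opacity` would use.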