Skip to content

Commit

Permalink
upd comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Linyou committed Apr 21, 2023
1 parent ba900e5 commit f62b72d
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 34 deletions.
7 changes: 4 additions & 3 deletions nerfacc/cuda/csrc/grid.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ __global__ void traverse_grids_kernel(
} else {
chunk_start = tid * max_samples_per_ray * 2;
chunk_start_bin = tid * max_samples_per_ray;
// ray_mask_id stores the original ray id for each test time ray.
tid_t = ray_mask_id[tid];
}

Expand Down Expand Up @@ -191,22 +192,22 @@ __global__ void traverse_grids_kernel(
if (!first_pass) { // left side of the intervel
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_last;
intervals.ray_indices[idx] = tid_t;
intervals.ray_indices[idx] = tid;
intervals.is_left[idx] = true;
}
n_intervals++;
if (!first_pass) { // right side of the intervel
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_next;
intervals.ray_indices[idx] = tid_t;
intervals.ray_indices[idx] = tid;
intervals.is_right[idx] = true;
}
n_intervals++;
} else {
if (!first_pass) { // right side of the intervel
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_next;
intervals.ray_indices[idx] = tid_t;
intervals.ray_indices[idx] = tid;
intervals.is_left[idx - 1] = true;
intervals.is_right[idx] = true;
}
Expand Down
34 changes: 17 additions & 17 deletions nerfacc/estimators/occ_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,29 +256,29 @@ def update_every_n_steps(
def mark_invisible_cells(
self,
K: Tensor,
poses: Tensor,
c2w: Tensor,
width: int,
height: int,
near_plane: float = 0.0,
chunk: int =32**3
) -> None:
"""Mark the cells that aren't covered by the cameras with density -1
only executed once before training starts.
Note:
This code is adapted from: https://github.com/kwea123/ngp_pl/blob/master/models/networks.py
Args:
K: Camera intrinsics of shape (3, 3).
c2w: Camera to world poses of shape (N, 3, 4).
width: Image width in pixels
height: Image height in pixels
near_plane: Near plane distance
chunk: The chunk size to split the cells (to avoid OOM)
"""
This code is adapted from: https://github.com/kwea123/ngp_pl/blob/master/models/networks.py
Mark the cells that aren't covered by the cameras with density -1
only executed once before training starts
Inputs:
K: (3, 3) camera intrinsics
poses: (N, 3, 4) camera to world poses
width: image width in pixels
height: image height in pixels
near_plane: near plane distance
chunk: the chunk size to split the cells (to avoid OOM)
"""
N_cams = poses.shape[0]
w2c_R = poses[:, :3, :3].transpose(2, 1) # (N_cams, 3, 3)
w2c_T = -w2c_R@poses[:, :3, 3:] # (N_cams, 3, 1)
N_cams = c2w.shape[0]
w2c_R = c2w[:, :3, :3].transpose(2, 1) # (N_cams, 3, 3)
w2c_T = -w2c_R@c2w[:, :3, 3:] # (N_cams, 3, 1)

lvl_indices = self._get_all_cells()
for lvl, indices in enumerate(lvl_indices):
Expand Down
6 changes: 6 additions & 0 deletions nerfacc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,16 @@ def traverse_grids(
step_size: Optional. Step size for ray traversal. Default to 1e-3.
cone_angle: Optional. Cone angle for linearly-increased step size. 0. means
constant step size. Default: 0.0.
max_samples_per_ray: Optional. Maximum number of samples per ray. Default to 4096.
ray_mask_id: Optional. (n_rays_chunk,) Ray mask id for each ray. Default to None.
t_sorted: Optional. (n_rays, n_grids) Pre-computed sorted t values for each ray-grid pair. Default to None.
t_indices: Optional. (n_rays, n_grids) Pre-computed sorted t indices for each ray-grid pair. Default to None.
hits: Optional. (n_rays, n_grids) Pre-computed hit flags for each ray-grid pair. Default to None.
Returns:
A :class:`RayIntervals` object containing the intervals of the ray traversal, and
a :class:`RaySamples` object containing the samples within each interval.
t :class:`Tensor` of shape (n_rays,) containing the terminated t values for each ray, only return `Tensor` when ray_mask_id is provided, otherwise return None.
"""

if near_planes is None:
Expand Down
37 changes: 23 additions & 14 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def test_traverse_grids_test_mode():
torch.manual_seed(42)
n_rays = 10
n_aabbs = 4
max_samples_per_ray = 100

ray_mask_id = torch.arange(n_rays, device=device)

Expand All @@ -89,6 +88,21 @@ def test_traverse_grids_test_mode():

binaries = torch.rand((n_aabbs, 32, 32, 32), device=device) > 0.5

# ref results: train mode
intervals, samples, _ = traverse_grids(rays_o, rays_d, binaries, aabbs)

ray_indices = samples.ray_indices
t_starts = intervals.vals[intervals.is_left]
t_ends = intervals.vals[intervals.is_right]
positions = (
rays_o[ray_indices]
+ rays_d[ray_indices] * (t_starts + t_ends)[:, None] / 2.0
)
occs, selector = _query(positions, binaries, base_aabb)
assert occs.all(), occs.float().mean()
assert selector.all(), selector.float().mean()

# test mode
t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
t_sorted, t_indices = torch.sort(torch.cat([t_mins, t_maxs], -1), -1)

Expand All @@ -97,24 +111,20 @@ def test_traverse_grids_test_mode():
rays_d,
binaries,
aabbs,
max_samples_per_ray=max_samples_per_ray,
max_samples_per_ray=4096,
ray_mask_id=ray_mask_id,
t_sorted=t_sorted,
t_indices=t_indices,
hits=hits
)
ray_indices_test = samples.ray_indices[samples.is_valid]
t_starts_test = intervals.vals[intervals.is_left]
t_ends_test = intervals.vals[intervals.is_right]

assert (ray_indices == ray_indices_test).all()
assert (t_starts == t_starts_test).all()
assert (t_ends == t_ends_test).all()

ray_indices = samples.ray_indices
t_starts = intervals.vals[intervals.is_left]
t_ends = intervals.vals[intervals.is_right]
positions = (
rays_o[ray_indices]
+ rays_d[ray_indices] * (t_starts + t_ends)[:, None] / 2.0
)
occs, selector = _query(positions, binaries, base_aabb)
assert occs.all(), occs.float().mean()
assert selector.all(), selector.float().mean()
assert positions.shape[0] == max_samples_per_ray * n_rays


@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
Expand Down Expand Up @@ -192,7 +202,6 @@ def test_sampling_with_min_max_distances():
def test_mark_invisible_cells():
from nerfacc import OccGridEstimator

torch.manual_seed(42)
levels = 4
resolution = 32
width = 100
Expand Down

0 comments on commit f62b72d

Please sign in to comment.