Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hot fix #43

Merged
merged 3 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
copyright = "2022, Ruilong"
author = "Ruilong"

release = "0.1.2"
version = "0.1.2"
release = "0.1.4"
version = "0.1.4"

# -- General configuration

Expand Down
8 changes: 6 additions & 2 deletions nerfacc/cuda/csrc/pybind.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ std::vector<torch::Tensor> rendering_forward(
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps,
float alpha_thre,
bool compression);

torch::Tensor rendering_backward(
Expand All @@ -17,7 +18,8 @@ torch::Tensor rendering_backward(
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps);
float early_stop_eps,
float alpha_thre);

std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o,
Expand Down Expand Up @@ -65,12 +67,14 @@ torch::Tensor rendering_alphas_backward(
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps);
float early_stop_eps,
float alpha_thre);

std::vector<torch::Tensor> rendering_alphas_forward(
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps,
float alpha_thre,
bool compression);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
Expand Down
38 changes: 33 additions & 5 deletions nerfacc/cuda/csrc/rendering.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __global__ void rendering_forward_kernel(
const scalar_t *sigmas, // input density after activation
const scalar_t *alphas, // input alpha (opacity) values.
const scalar_t early_stop_eps, // transmittance threshold for early stop
const scalar_t alpha_thre, // alpha threshold for emtpy space
// outputs: should be all-zero initialized
int *num_steps, // the number of valid steps for each ray
scalar_t *weights, // the number rendering weights for each sample
Expand Down Expand Up @@ -51,8 +52,8 @@ __global__ void rendering_forward_kernel(

// accumulated rendering
scalar_t T = 1.f;
int j = 0;
for (; j < steps; ++j)
int cnt = 0;
for (int j = 0; j < steps; ++j)
{
if (T < early_stop_eps)
{
Expand All @@ -70,6 +71,11 @@ __global__ void rendering_forward_kernel(
scalar_t delta = ends[j] - starts[j];
alpha = 1.f - __expf(-sigmas[j] * delta);
}
if (alpha < alpha_thre)
{
// empty space
continue;
}
const scalar_t weight = alpha * T;
T *= (1.f - alpha);
if (weights != nullptr)
Expand All @@ -80,10 +86,11 @@ __global__ void rendering_forward_kernel(
{
compact_selector[j] = true;
}
cnt += 1;
}
if (num_steps != nullptr)
{
*num_steps = j;
*num_steps = cnt;
}
return;
}
Expand All @@ -97,6 +104,7 @@ __global__ void rendering_backward_kernel(
const scalar_t *sigmas, // input density after activation
const scalar_t *alphas, // input alpha (opacity) values.
const scalar_t early_stop_eps, // transmittance threshold for early stop
const scalar_t alpha_thre, // alpha threshold for emtpy space
const scalar_t *weights, // forward output
const scalar_t *grad_weights, // input gradients
// if alphas was given, we compute the gradients for alphas.
Expand Down Expand Up @@ -150,13 +158,23 @@ __global__ void rendering_backward_kernel(
{
// rendering with alpha
alpha = alphas[j];
if (alpha < alpha_thre)
{
// empty space
continue;
}
grad_alphas[j] = (grad_weights[j] * T - accum) / fmaxf(1.f - alpha, 1e-10f);
}
else
{
// rendering with density
scalar_t delta = ends[j] - starts[j];
alpha = 1.f - __expf(-sigmas[j] * delta);
if (alpha < alpha_thre)
{
// empty space
continue;
}
grad_sigmas[j] = (grad_weights[j] * T - accum) * delta;
}

Expand All @@ -171,6 +189,7 @@ std::vector<torch::Tensor> rendering_forward(
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps,
float alpha_thre,
bool compression)
{
DEVICE_GUARD(packed_info);
Expand Down Expand Up @@ -211,6 +230,7 @@ std::vector<torch::Tensor> rendering_forward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
// outputs
num_steps.data_ptr<int>(),
nullptr,
Expand Down Expand Up @@ -238,6 +258,7 @@ std::vector<torch::Tensor> rendering_forward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
// outputs
nullptr,
weights.data_ptr<scalar_t>(),
Expand All @@ -254,7 +275,8 @@ torch::Tensor rendering_backward(
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps)
float early_stop_eps,
float alpha_thre)
{
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
Expand All @@ -279,6 +301,7 @@ torch::Tensor rendering_backward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
// outputs
Expand All @@ -295,6 +318,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps,
float alpha_thre,
bool compression)
{
DEVICE_GUARD(packed_info);
Expand Down Expand Up @@ -331,6 +355,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
// outputs
num_steps.data_ptr<int>(),
nullptr,
Expand Down Expand Up @@ -358,6 +383,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
// outputs
nullptr,
weights.data_ptr<scalar_t>(),
Expand All @@ -372,7 +398,8 @@ torch::Tensor rendering_alphas_backward(
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps)
float early_stop_eps,
float alpha_thre)
{
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
Expand All @@ -397,6 +424,7 @@ torch::Tensor rendering_alphas_backward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
// outputs
Expand Down
4 changes: 3 additions & 1 deletion nerfacc/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def rendering(
t_ends: torch.Tensor,
# rendering options
early_stop_eps: float = 1e-4,
alpha_thre: float = 1e-2,
render_bkgd: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Render the rays through the radience field defined by `rgb_sigma_fn`.
Expand All @@ -33,6 +34,7 @@ def rendering(
t_starts: Per-sample start distance. Tensor with shape (n_samples, 1).
t_ends: Per-sample end distance. Tensor with shape (n_samples, 1).
early_stop_eps: Early stop threshold during trasmittance accumulation. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
render_bkgd: Optional. Background color. Tensor with shape (3,).

Returns:
Expand Down Expand Up @@ -82,7 +84,7 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices):

# Rendering: compute weights and ray indices.
weights = render_weight_from_density(
packed_info, t_starts, t_ends, sigmas, early_stop_eps
packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre
)

# Rendering: accumulate rgbs, opacities, and depths along the rays.
Expand Down
4 changes: 3 additions & 1 deletion nerfacc/ray_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def ray_marching(
# sigma function for skipping invisible space
sigma_fn: Optional[Callable] = None,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
# rendering options
near_plane: Optional[float] = None,
far_plane: Optional[float] = None,
Expand Down Expand Up @@ -140,6 +141,7 @@ def ray_marching(
function that takes in samples {t_starts (N, 1), t_ends (N, 1),
ray indices (N,)} and returns the post-activation density values (N, 1).
early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
near_plane: Optional. Near plane distance. If provided, it will be used
to clip t_min.
far_plane: Optional. Far plane distance. If provided, it will be used
Expand Down Expand Up @@ -272,7 +274,7 @@ def ray_marching(

# Compute visibility of the samples, and filter out invisible samples
visibility, packed_info_visible = render_visibility(
packed_info, alphas, early_stop_eps
packed_info, alphas, early_stop_eps, alpha_thre
)
t_starts, t_ends = t_starts[visibility], t_ends[visibility]
packed_info = packed_info_visible
Expand Down
Loading