Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy load ContractionType #39

Merged
merged 1 commit into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions nerfacc/contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ class ContractionType(Enum):
UN_BOUNDED_TANH = 1
UN_BOUNDED_SPHERE = 2

def to_cpp_version(self):
    """Look up the equivalent C++ enum member for this Python enum.

    Returns:
        The C++ ``ContractionType`` object whose integer value matches
        ``self.value``, obtained from the compiled CUDA backend.
    """
    # Delegate to the lazily-loaded backend getter; this only touches the
    # compiled extension at call time, never at import time.
    cpp_member = _C.ContractionTypeGetter(self.value)
    return cpp_member


@torch.no_grad()
def contract(
Expand All @@ -65,7 +74,7 @@ def contract(
Returns:
torch.Tensor: Contracted points ([0, 1]^3).
"""
ctype = _C.ContractionType(type.value)
ctype = type.to_cpp_version()
return _C.contract(x.contiguous(), roi.contiguous(), ctype)


Expand All @@ -85,5 +94,5 @@ def contract_inv(
Returns:
torch.Tensor: Un-contracted points.
"""
ctype = _C.ContractionType(type.value)
ctype = type.to_cpp_version()
return _C.contract_inv(x.contiguous(), roi.contiguous(), ctype)
12 changes: 1 addition & 11 deletions nerfacc/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,7 @@ def call_cuda(*args, **kwargs):
return call_cuda


def _make_lazy_cuda_attribute(name: str) -> Any:
    """Fetch attribute ``name`` from the compiled CUDA backend.

    Args:
        name: Attribute name to look up on the ``_C`` extension module.

    Returns:
        The backend attribute, or ``None`` when the extension failed to
        load (``_C`` is ``None``).
    """
    # Import deferred to call time so merely importing this package does
    # not force the CUDA extension to build/load.
    # pylint: disable=import-outside-toplevel
    from ._backend import _C

    return None if _C is None else getattr(_C, name)


ContractionType = _make_lazy_cuda_attribute("ContractionType")
ContractionTypeGetter = _make_lazy_cuda_func("ContractionType")
contract = _make_lazy_cuda_func("contract")
contract_inv = _make_lazy_cuda_func("contract_inv")

Expand Down
5 changes: 3 additions & 2 deletions nerfacc/ray_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import Tensor

import nerfacc.cuda as _C
from nerfacc.contraction import ContractionType

from .grid import Grid
from .vol_rendering import render_visibility
Expand Down Expand Up @@ -231,7 +232,7 @@ def ray_marching(
if grid is not None:
grid_roi_aabb = grid.roi_aabb
grid_binary = grid.binary
contraction_type = _C.ContractionType(grid.contraction_type.value)
contraction_type = grid.contraction_type.to_cpp_version()
else:
grid_roi_aabb = torch.tensor(
[-1e10, -1e10, -1e10, 1e10, 1e10, 1e10],
Expand All @@ -241,7 +242,7 @@ def ray_marching(
grid_binary = torch.ones(
[1, 1, 1], dtype=torch.bool, device=rays_o.device
)
contraction_type = _C.ContractionType.AABB
contraction_type = ContractionType.AABB.to_cpp_version()

# marching with grid-based skipping
packed_info, t_starts, t_ends = _C.ray_marching(
Expand Down