Skip to content

Commit

Permalink
shader: add SoftDepthShader and HardDepthShader for rendering depth m…
Browse files Browse the repository at this point in the history
…aps (#36)

Summary:
X-link: fairinternal/pytorch3d#36

This adds two shaders for rendering depth maps for meshes. This is useful for structure from motion applications that learn depths based off of camera pair disparities.

There's two shaders, one hard which just returns the distances and then a second that does a cumsum on the probabilities of the points with a weighted sum. Areas that don't have any z faces are set to the zfar distance.

Output from this renderer is `[N, H, W]` since it's just depth no need for channels.

I haven't tested this in an ML model yet just in a notebook.

hard:
![hardzshader](https://user-images.githubusercontent.com/909104/170190363-ef662c97-0bd2-488c-8675-0557a3c7dd06.png)

soft:
![softzshader](https://user-images.githubusercontent.com/909104/170190365-65b08cd7-0c49-4119-803e-d33c1d8c676e.png)

Pull Request resolved: #1208

Reviewed By: bottler

Differential Revision: D36682194

Pulled By: d4l3k

fbshipit-source-id: 5d4e10c6fb0fff5427be4ddd3bd76305a7ccc1e2
  • Loading branch information
d4l3k authored and facebook-github-bot committed Jun 26, 2022
1 parent 0e4c53c commit 7e0146e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
68 changes: 68 additions & 0 deletions pytorch3d/renderer/mesh/shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,71 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
)

return images


class HardDepthShader(ShaderBase):
"""
Renders the Z distances of the closest face for each pixel. If no face is
found it returns the zfar value of the camera.
Output from this shader is [N, H, W, 1] since it's only depth.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = HardDepthShader(device=torch.device("cuda:0"))
"""

def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = super()._get_cameras(**kwargs)

zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
mask = fragments.pix_to_face < 0

zbuf = fragments.zbuf[..., 0].clone()
zbuf[mask] = zfar
return zbuf.unsqueeze(3)


class SoftDepthShader(ShaderBase):
"""
Renders the Z distances using an aggregate of the distances of each face
based off of the point distance. If no face is found it returns the zfar
value of the camera.
Output from this shader is [N, H, W, 1] since it's only depth.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = SoftDepthShader(device=torch.device("cuda:0"))
"""

def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = super()._get_cameras(**kwargs)

N, H, W, K = fragments.pix_to_face.shape
device = fragments.zbuf.device
mask = fragments.pix_to_face >= 0

zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))

# Sigmoid probability map based on the distance of the pixel to the face.
prob_map = torch.sigmoid(-fragments.dists / self.blend_params.sigma) * mask

# append extra face for zfar
dists = torch.cat(
(fragments.zbuf, torch.ones((N, H, W, 1), device=device) * zfar), dim=3
)
probs = torch.cat((prob_map, torch.ones((N, H, W, 1), device=device)), dim=3)

# compute weighting based off of probabilities using cumsum
probs = probs.cumsum(dim=3)
probs = probs.clamp(max=1)
probs = probs.diff(dim=3, prepend=torch.zeros((N, H, W, 1), device=device))

return (probs * dists).sum(dim=3).unsqueeze(3)
4 changes: 4 additions & 0 deletions tests/test_shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.shader import (
HardDepthShader,
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftDepthShader,
SoftPhongShader,
SplatterPhongShader,
)
Expand All @@ -24,9 +26,11 @@
class TestShader(TestCaseMixin, unittest.TestCase):
def setUp(self):
self.shader_classes = [
HardDepthShader,
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftDepthShader,
SoftPhongShader,
SplatterPhongShader,
]
Expand Down

0 comments on commit 7e0146e

Please sign in to comment.