Skip to content

Commit

Permalink
renderer: add support for rendering high dimensional textures for cla…
Browse files Browse the repository at this point in the history
…ssification/segmentation use cases (#1248)

Summary:
For 3D segmentation problems it's really useful to be able to train the models from multiple viewpoints using Pytorch3D as the renderer. Currently due to hardcoded assumptions in a few spots the mesh renderer only supports rendering RGB (3 dimensional) data. You can encode the classification information as 3 channel data but if you have more than 3 classes you're out of luck.

This relaxes the assumptions to make rendering semantic classes work with `HardFlatShader` and `AmbientLights` with no diffusion/specular. The other shaders/lights don't make any sense for classification since they mutate the texture values in some way.

This only requires changes in `Materials` and `AmbientLights`. The bulk of the code is the unit test.

Pull Request resolved: #1248

Test Plan: Added unit test that renders a 5 dimensional texture and compare dimensions 2-5 to a stored picture.

Reviewed By: bottler

Differential Revision: D37764610

Pulled By: d4l3k

fbshipit-source-id: 031895724d9318a6f6bab5b31055bb3f438176a5
  • Loading branch information
d4l3k authored and facebook-github-bot committed Jul 12, 2022
1 parent aa8b03f commit 8d10ba5
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 12 deletions.
19 changes: 14 additions & 5 deletions pytorch3d/renderer/lighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ class AmbientLights(TensorProperties):
A light object representing the same color of light everywhere.
By default, this is white, which effectively means lighting is
not used in rendering.
Unlike other lights this supports an arbitrary number of channels, not just 3 for RGB.
The ambient_color input determines the number of channels.
"""

def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
Expand All @@ -304,9 +307,11 @@ def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
device: Device (as str or torch.device) on which the tensors should be located
The ambient_color if provided, should be
- 3 element tuple/list or list of lists
- torch tensor of shape (1, 3)
- torch tensor of shape (N, 3)
- tuple/list of C-element tuples of floats
- torch tensor of shape (1, C)
- torch tensor of shape (N, C)
where C is the number of channels and N is batch size.
For RGB, C is 3.
"""
if ambient_color is None:
ambient_color = ((1.0, 1.0, 1.0),)
Expand All @@ -317,10 +322,14 @@ def clone(self):
return super().clone(other)

def diffuse(self, normals, points) -> torch.Tensor:
return torch.zeros_like(points)
return self._zeros_channels(points)

def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
return torch.zeros_like(points)
return self._zeros_channels(points)

def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
ch = self.ambient_color.shape[-1]
return torch.zeros(*points.shape[:-1], ch, device=points.device)


def _validate_light_properties(obj) -> None:
Expand Down
16 changes: 9 additions & 7 deletions pytorch3d/renderer/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ def __init__(
) -> None:
"""
Args:
ambient_color: RGB ambient reflectivity of the material
diffuse_color: RGB diffuse reflectivity of the material
specular_color: RGB specular reflectivity of the material
ambient_color: ambient reflectivity of the material
diffuse_color: diffuse reflectivity of the material
specular_color: specular reflectivity of the material
shininess: The specular exponent for the material. This defines
the focus of the specular highlight with a high value
resulting in a concentrated highlight. Shininess values
can range from 0-1000.
device: Device (as str or torch.device) on which the tensors should be located
ambient_color, diffuse_color and specular_color can be of shape
(1, 3) or (N, 3). shininess can be of shape (1) or (N).
(1, C) or (N, C) where C is typically 3 (for RGB). shininess can be of shape (1,)
or (N,).
The colors and shininess are broadcast against each other so need to
have either the same batch dimension or batch dimension = 1.
Expand All @@ -49,11 +50,12 @@ def __init__(
specular_color=specular_color,
shininess=shininess,
)
C = self.ambient_color.shape[-1]
for n in ["ambient_color", "diffuse_color", "specular_color"]:
t = getattr(self, n)
if t.shape[-1] != 3:
msg = "Expected %s to have shape (N, 3); got %r"
raise ValueError(msg % (n, t.shape))
if t.shape[-1] != C:
msg = "Expected %s to have shape (N, %d); got %r"
raise ValueError(msg % (n, C, t.shape))
if self.shininess.shape != torch.Size([self._N]):
msg = "shininess should have shape (N); got %r"
raise ValueError(msg % repr(self.shininess.shape))
Expand Down
Binary file added tests/data/test_nd_sphere.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
78 changes: 78 additions & 0 deletions tests/test_render_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,3 +1236,81 @@ def test_cameras_kwarg(self):
"test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
)
self.assertClose(rgb, image_ref, atol=0.05)

def test_nd_sphere(self):
"""
Test that the render can handle textures with more than 3 channels and
not just 3 channel RGB.
"""
torch.manual_seed(1)
device = torch.device("cuda:0")
C = 5
WHITE = ((1.0,) * C,)
BLACK = ((0.0,) * C,)

# Init mesh
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
n_verts = feats.shape[1]
# make some non-uniform pattern
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)

# No elevation or azimuth rotation
R, T = look_at_view_transform(2.7, 0.0, 0.0)

cameras = PerspectiveCameras(device=device, R=R, T=T)

# Init shader settings
materials = Materials(
device=device,
ambient_color=WHITE,
diffuse_color=WHITE,
specular_color=WHITE,
)
lights = AmbientLights(
device=device,
ambient_color=WHITE,
)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
blend_params = BlendParams(
1e-4,
1e-4,
background_color=BLACK[0],
)

# only test HardFlatShader since that's the only one that makes
# sense for classification
shader = HardFlatShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)

self.assertEqual(images.shape[-1], C + 1)
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)

# grab last 3 color channels
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
filename = "test_nd_sphere.png"

if DEBUG:
debug_filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / debug_filename
)

image_ref = load_rgb_image(filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)

0 comments on commit 8d10ba5

Please sign in to comment.