diff --git a/pytorch3d/renderer/lighting.py b/pytorch3d/renderer/lighting.py
index 3ca81ca77..a11ffbe40 100644
--- a/pytorch3d/renderer/lighting.py
+++ b/pytorch3d/renderer/lighting.py
@@ -292,6 +292,9 @@ class AmbientLights(TensorProperties):
     A light object representing the same color of light everywhere. By default,
     this is white, which effectively means lighting is not used
     in rendering.
+
+    Unlike other lights this supports an arbitrary number of channels, not just 3 for RGB.
+    The ambient_color input determines the number of channels.
     """
 
     def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
@@ -304,9 +307,11 @@ def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
             device: Device (as str or torch.device) on which the tensors should be located
 
         The ambient_color if provided, should be
-            - 3 element tuple/list or list of lists
-            - torch tensor of shape (1, 3)
-            - torch tensor of shape (N, 3)
+            - tuple/list of C-element tuples of floats
+            - torch tensor of shape (1, C)
+            - torch tensor of shape (N, C)
+            where C is the number of channels and N is batch size.
+            For RGB, C is 3.
         """
         if ambient_color is None:
             ambient_color = ((1.0, 1.0, 1.0),)
@@ -317,10 +322,14 @@ def clone(self):
         return super().clone(other)
 
     def diffuse(self, normals, points) -> torch.Tensor:
-        return torch.zeros_like(points)
+        return self._zeros_channels(points)
 
     def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
-        return torch.zeros_like(points)
+        return self._zeros_channels(points)
+
+    def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
+        ch = self.ambient_color.shape[-1]
+        return torch.zeros(*points.shape[:-1], ch, device=points.device)
 
 
 def _validate_light_properties(obj) -> None:
diff --git a/pytorch3d/renderer/materials.py b/pytorch3d/renderer/materials.py
index 738808b72..27558ed8a 100644
--- a/pytorch3d/renderer/materials.py
+++ b/pytorch3d/renderer/materials.py
@@ -27,9 +27,9 @@ def __init__(
     ) -> None:
         """
         Args:
-            ambient_color: RGB ambient reflectivity of the material
-            diffuse_color: RGB diffuse reflectivity of the material
-            specular_color: RGB specular reflectivity of the material
+            ambient_color: ambient reflectivity of the material
+            diffuse_color: diffuse reflectivity of the material
+            specular_color: specular reflectivity of the material
             shininess: The specular exponent for the material. This defines
                 the focus of the specular highlight with a high value
                 resulting in a concentrated highlight. Shininess values
@@ -37,7 +37,8 @@ def __init__(
             device: Device (as str or torch.device) on which the tensors should be located
 
         ambient_color, diffuse_color and specular_color can be of shape
-        (1, 3) or (N, 3). shininess can be of shape (1) or (N).
+        (1, C) or (N, C) where C is typically 3 (for RGB). shininess can be of shape (1,)
+        or (N,).
         The colors and shininess are broadcast against each other so need
         to have either the same batch dimension or batch dimension = 1.
 
@@ -49,11 +50,12 @@ def __init__(
             specular_color=specular_color,
             shininess=shininess,
         )
+        C = self.ambient_color.shape[-1]
         for n in ["ambient_color", "diffuse_color", "specular_color"]:
             t = getattr(self, n)
-            if t.shape[-1] != 3:
-                msg = "Expected %s to have shape (N, 3); got %r"
-                raise ValueError(msg % (n, t.shape))
+            if t.shape[-1] != C:
+                msg = "Expected %s to have shape (N, %d); got %r"
+                raise ValueError(msg % (n, C, t.shape))
         if self.shininess.shape != torch.Size([self._N]):
             msg = "shininess should have shape (N); got %r"
             raise ValueError(msg % repr(self.shininess.shape))
diff --git a/tests/data/test_nd_sphere.png b/tests/data/test_nd_sphere.png
new file mode 100644
index 000000000..cf2f9a9ea
Binary files /dev/null and b/tests/data/test_nd_sphere.png differ
diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py
index b47a78f2f..1dcd8e1cc 100644
--- a/tests/test_render_meshes.py
+++ b/tests/test_render_meshes.py
@@ -1236,3 +1236,81 @@ def test_cameras_kwarg(self):
             "test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
         )
         self.assertClose(rgb, image_ref, atol=0.05)
+
+    def test_nd_sphere(self):
+        """
+        Test that the render can handle textures with more than 3 channels and
+        not just 3 channel RGB.
+        """
+        torch.manual_seed(1)
+        device = torch.device("cuda:0")
+        C = 5
+        WHITE = ((1.0,) * C,)
+        BLACK = ((0.0,) * C,)
+
+        # Init mesh
+        sphere_mesh = ico_sphere(5, device)
+        verts_padded = sphere_mesh.verts_padded()
+        faces_padded = sphere_mesh.faces_padded()
+        feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
+        n_verts = feats.shape[1]
+        # make some non-uniform pattern
+        feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
+        textures = TexturesVertex(verts_features=feats)
+        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
+
+        # No elevation or azimuth rotation
+        R, T = look_at_view_transform(2.7, 0.0, 0.0)
+
+        cameras = PerspectiveCameras(device=device, R=R, T=T)
+
+        # Init shader settings
+        materials = Materials(
+            device=device,
+            ambient_color=WHITE,
+            diffuse_color=WHITE,
+            specular_color=WHITE,
+        )
+        lights = AmbientLights(
+            device=device,
+            ambient_color=WHITE,
+        )
+        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+
+        raster_settings = RasterizationSettings(
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
+        )
+        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
+        blend_params = BlendParams(
+            1e-4,
+            1e-4,
+            background_color=BLACK[0],
+        )
+
+        # only test HardFlatShader since that's the only one that makes
+        # sense for classification
+        shader = HardFlatShader(
+            lights=lights,
+            cameras=cameras,
+            materials=materials,
+            blend_params=blend_params,
+        )
+        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+        images = renderer(sphere_mesh)
+
+        self.assertEqual(images.shape[-1], C + 1)
+        self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
+        self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
+
+        # grab last 3 color channels
+        rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
+        filename = "test_nd_sphere.png"
+
+        if DEBUG:
+            debug_filename = "DEBUG_%s" % filename
+            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                DATA_DIR / debug_filename
+            )
+
+        image_ref = load_rgb_image(filename, DATA_DIR)
+        self.assertClose(rgb, image_ref, atol=0.05)
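
For reference, the snippet below is a minimal usage sketch, not part of the patch, showing how the multi-channel path exercised by test_nd_sphere could be driven from user code once the changes above are applied. The variable names (C, white, mesh, renderer) and the image size are illustrative choices, not APIs introduced by this diff.

# Usage sketch (illustrative, not part of the patch): render a mesh whose
# per-vertex features have C = 5 channels, assuming the changes above are applied.
import torch
from pytorch3d.renderer import (
    AmbientLights,
    BlendParams,
    HardFlatShader,
    Materials,
    MeshRasterizer,
    MeshRenderer,
    PerspectiveCameras,
    RasterizationSettings,
    TexturesVertex,
    look_at_view_transform,
)
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
C = 5  # number of feature channels; 3 would be ordinary RGB

# Build a sphere with a random C-channel per-vertex texture.
sphere = ico_sphere(3, device)
verts, faces = sphere.verts_padded(), sphere.faces_padded()
feats = torch.rand(verts.shape[0], verts.shape[1], C, device=device)
mesh = Meshes(verts=verts, faces=faces, textures=TexturesVertex(verts_features=feats))

R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = PerspectiveCameras(device=device, R=R, T=T)

# C-channel ambient colors set the channel count for the lights and materials.
white = ((1.0,) * C,)
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, raster_settings=RasterizationSettings(image_size=256)
    ),
    # Flat shading with ambient-only lighting, as in the new test, since
    # diffuse/specular shading is not meaningful for non-color features.
    shader=HardFlatShader(
        device=device,
        cameras=cameras,
        lights=AmbientLights(device=device, ambient_color=white),
        materials=Materials(
            device=device,
            ambient_color=white,
            diffuse_color=white,
            specular_color=white,
        ),
        blend_params=BlendParams(background_color=(0.0,) * C),
    ),
)

images = renderer(mesh)  # shape (1, 256, 256, C + 1): C feature channels plus alpha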