Skip to content

Commit

Permalink
images for debugging TexturesUV
Browse files Browse the repository at this point in the history
Summary: New methods to directly plot a TexturesUV map with its used points, using PIL and matplotlib.

Reviewed By: gkioxari

Differential Revision: D23782968

fbshipit-source-id: 692970857b5be13a35a3175dc82ac03963a73555
  • Loading branch information
bottler authored and facebook-github-bot committed Oct 27, 2020
1 parent b149bbf commit aa4cc0a
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 4 deletions.
27 changes: 23 additions & 4 deletions docs/tutorials/render_textured_meshes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
"\n",
"# Data structures and functions for rendering\n",
"from pytorch3d.structures import Meshes\n",
"from pytorch3d.vis import AxisArgs, plot_batch_individually, plot_scene\n",
"from pytorch3d.vis import AxisArgs, plot_batch_individually, plot_scene, texturesuv_image_matplotlib\n",
"from pytorch3d.renderer import (\n",
" look_at_view_transform,\n",
" FoVPerspectiveCameras, \n",
Expand Down Expand Up @@ -236,8 +236,7 @@
"obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n",
"\n",
"# Load obj file\n",
"mesh = load_objs_as_meshes([obj_filename], device=device)\n",
"texture_image=mesh.textures.maps_padded()"
"mesh = load_objs_as_meshes([obj_filename], device=device)"
]
},
{
Expand Down Expand Up @@ -265,9 +264,29 @@
"outputs": [],
"source": [
"plt.figure(figsize=(7,7))\n",
"texture_image=mesh.textures.maps_padded()\n",
"plt.imshow(texture_image.squeeze().cpu().numpy())\n",
"plt.grid(\"off\");\n",
"plt.axis('off');"
"plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch3D has a built-in way to view the texture map with matplotlib along with the points on the map corresponding to vertices. There is also a method, texturesuv_image_PIL, to get a similar image which can be saved to a file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(7,7))\n",
"texturesuv_image_matplotlib(mesh.textures, subsample=None)\n",
"plt.grid(\"off\");\n",
"plt.axis(\"off\");"
]
},
{
Expand Down
36 changes: 36 additions & 0 deletions pytorch3d/renderer/mesh/textures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,42 @@ def join_scene(self) -> "TexturesUV":
padding_mode=self.padding_mode,
)

def centers_for_image(self, index):
"""
Return the locations in the texture map which correspond to the given
verts_uvs, for one of the meshes. This is potentially useful for
visualizing the data. See the texturesuv_image_matplotlib and
texturesuv_image_PIL functions.
Args:
index: batch index of the mesh whose centers to return.
Returns:
centers: coordinates of points in the texture image
- a FloatTensor of shape (V,2)
"""
if self._N != 1:
raise ValueError(
"This function only supports plotting textures for one mesh."
)
texture_image = self.maps_padded()
verts_uvs = self.verts_uvs_list()[index][None]
_, H, W, _3 = texture_image.shape
coord1 = torch.arange(W).expand(H, W)
coord2 = torch.arange(H)[:, None].expand(H, W)
coords = torch.stack([coord1, coord2])[None]
with torch.no_grad():
# Get xy cartesian coordinates based on the uv coordinates
centers = F.grid_sample(
torch.flip(coords.to(texture_image), [2]),
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
verts_uvs[:, None] * 2.0 - 1,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
).cpu()
centers = centers[0, :, 0].T
return centers


class TexturesVertex(TexturesBase):
def __init__(
Expand Down
1 change: 1 addition & 0 deletions pytorch3d/vis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from .plotly_vis import AxisArgs, Lighting, plot_batch_individually, plot_scene
from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL


__all__ = [k for k in globals().keys() if not k.startswith("_")]
104 changes: 104 additions & 0 deletions pytorch3d/vis/texture_vis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional

import numpy as np
from PIL import Image, ImageDraw
from pytorch3d.renderer.mesh import TexturesUV


def texturesuv_image_matplotlib(
texture: TexturesUV,
*,
texture_index: int = 0,
radius: float = 1,
color=(1.0, 0.0, 0.0),
subsample: Optional[int] = 10000,
origin: str = "upper",
):
"""
Plot the texture image for one element of a TexturesUV with
matplotlib together with verts_uvs positions circled.
In particular a value in verts_uvs which is never referenced
in faces_uvs will still be plotted.
This is for debugging purposes, e.g. to align the map with
the uv coordinates. In particular, matplotlib
is used which is not an official dependency of PyTorch3D.
Args:
texture: a TexturesUV object with one mesh
texture_index: index in the batch to plot
radius: plotted circle radius in pixels
color: any matplotlib-understood color for the circles.
subsample: if not None, number of points to plot.
Otherwise all points are plotted.
origin: "upper" or "lower" like matplotlib.imshow
"""

import matplotlib.pyplot as plt
from matplotlib.patches import Circle

texture_image = texture.maps_padded()
centers = texture.centers_for_image(index=texture_index).numpy()

ax = plt.gca()
ax.imshow(texture_image[texture_index].detach().cpu().numpy(), origin=origin)

n_points = centers.shape[0]
if subsample is None or n_points <= subsample:
indices = range(n_points)
else:
indices = np.random.choice(n_points, subsample, replace=False)
for i in indices:
# setting clip_on=False makes it obvious when
# we have UV coordinates outside the correct range
ax.add_patch(Circle(centers[i], radius, color=color, clip_on=False))


def texturesuv_image_PIL(
texture: TexturesUV,
*,
texture_index: int = 0,
radius: float = 1,
color="red",
subsample: Optional[int] = 10000,
):
"""
Return a PIL image of the texture image of one element of the batch
from a TexturesUV, together with the verts_uvs positions circled.
In particular a value in verts_uvs which is never referenced
in faces_uvs will still be plotted.
This is for debugging purposes, e.g. to align the map with
the uv coordinates. In particular, matplotlib
is used which is not an official dependency of PyTorch3D.
Args:
texture: a TexturesUV object with one mesh
texture_index: index in the batch to plot
radius: plotted circle radius in pixels
color: any PIL-understood color for the circles.
subsample: if not None, number of points to plot.
Otherwise all points are plotted.
Returns:
PIL Image object.
"""

centers = texture.centers_for_image(index=texture_index).numpy()
texture_image = texture.maps_padded()
texture_array = (texture_image[texture_index] * 255).cpu().numpy().astype(np.uint8)

image = Image.fromarray(texture_array)
draw = ImageDraw.Draw(image)

n_points = centers.shape[0]
if subsample is None or n_points <= subsample:
indices = range(n_points)
else:
indices = np.random.choice(n_points, subsample, replace=False)

for i in indices:
x = centers[i][0]
y = centers[i][1]
draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], fill=color)

return image
Binary file added tests/data/texturesuv_debug.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions tests/test_texturing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@


import unittest
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from PIL import Image
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.textures import (
TexturesAtlas,
Expand All @@ -15,9 +18,14 @@
pack_rectangles,
)
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
from pytorch3d.vis import texturesuv_image_PIL
from test_meshes import TestMeshes


DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"


def tryindex(self, index, tex, meshes, source):
tex2 = tex[index]
meshes2 = meshes[index]
Expand Down Expand Up @@ -471,6 +479,10 @@ def test_getitem(self):


class TestTexturesUV(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)

def test_sample_textures_uv(self):
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
Expand Down Expand Up @@ -821,6 +833,22 @@ def test_getitem(self):
tryindex(self, index, tex, meshes, source)
tryindex(self, [2, 4], tex, meshes, source)

def test_png_debug(self):
maps = torch.rand(size=(1, 256, 128, 3)) * torch.tensor([0.8, 1, 0.8])
verts_uvs = torch.rand(size=(1, 20, 2))
faces_uvs = torch.zeros(size=(1, 0, 3), dtype=torch.int64)
tex = TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)

image = texturesuv_image_PIL(tex, radius=3)
image_out = np.array(image)
if DEBUG:
image.save(DATA_DIR / "texturesuv_debug_.png")

with Image.open(DATA_DIR / "texturesuv_debug.png") as image_ref_file:
image_ref = np.array(image_ref_file)

self.assertClose(image_out, image_ref)


class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit aa4cc0a

Please sign in to comment.