Skip to content

Commit

Permalink
Fix texture atlas for objs which only have material properties
Browse files Browse the repository at this point in the history
Summary:
Fix for GitHub issue #381.

The example mesh provided in the issue only had material properties but no texture image. The current implementation of texture atlassing generated an atlas using both the material properties and the texture image but only worked if there was a texture image and associated vertex uv coordinates. I have now modified the texture atlas creation so that it doesn't require an image and can work with materials which only have material properties.

Reviewed By: gkioxari

Differential Revision: D24153068

fbshipit-source-id: 63e9d325db09a84b336b83369d5342ce588a9932
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Oct 7, 2020
1 parent 5d65a0c commit f5383a7
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 37 deletions.
55 changes: 31 additions & 24 deletions pytorch3d/io/mtl_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def make_mesh_texture_atlas(
material_properties: Dict,
texture_images: Dict,
face_material_names,
faces_verts_uvs: torch.Tensor,
faces_uvs: torch.Tensor,
verts_uvs: torch.Tensor,
texture_size: int,
texture_wrap: Optional[str],
) -> torch.Tensor:
Expand All @@ -31,8 +32,9 @@ def make_mesh_texture_atlas(
face_material_names: numpy array of the material name corresponding to each
face. Faces which don't have an associated material will be an empty string.
For these faces, a uniform white texture is assigned.
faces_verts_uvs: LongTensor of shape (F, 3, 2) giving the uv coordinates for each
vertex in the face.
faces_uvs: LongTensor of shape (F, 3,) giving the index into the verts_uvs for
each face in the mesh.
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinates for each vertex.
texture_size: the resolution of the per face texture map returned by this function.
Each face will have a texture map of shape (texture_size, texture_size, 3).
texture_wrap: string, one of ["repeat", "clamp", None]
Expand All @@ -47,50 +49,55 @@ def make_mesh_texture_atlas(
"""
# Create an R x R texture map per face in the mesh
R = texture_size
F = faces_verts_uvs.shape[0]
F = faces_uvs.shape[0]

# Initialize the per face texture map to a white color.
# TODO: allow customization of this base color?
# pyre-fixme[16]: `Tensor` has no attribute `new_ones`.
atlas = faces_verts_uvs.new_ones(size=(F, R, R, 3))
atlas = torch.ones(size=(F, R, R, 3), dtype=torch.float32, device=faces_uvs.device)

# Check for empty materials.
if not material_properties and not texture_images:
return atlas

# Iterate through the material properties - not
# all materials have texture images so this is
# done first separately to the texture interpolation.
for material_name, props in material_properties.items():
# Bool to indicate which faces use this texture map.
faces_material_ind = torch.from_numpy(face_material_names == material_name).to(
faces_uvs.device
)
if faces_material_ind.sum() > 0:
# For these faces, update the base color to the
# diffuse material color.
if "diffuse_color" not in props:
continue
atlas[faces_material_ind, ...] = props["diffuse_color"][None, :]

# If there are vertex texture coordinates, create an (F, 3, 2)
# tensor of the vertex textures per face.
faces_verts_uvs = verts_uvs[faces_uvs] if len(verts_uvs) > 0 else None

# Some meshes only have material properties and no texture image.
# In this case, return the atlas here.
if faces_verts_uvs is None:
return atlas

if texture_wrap == "repeat":
# If texture uv coordinates are outside the range [0, 1] follow
# the convention GL_REPEAT in OpenGL i.e the integer part of the coordinate
# will be ignored and a repeating pattern is formed.
# Shapenet data uses this format see:
# https://shapenet.org/qaforum/index.php?qa=15&qa_1=why-is-the-texture-coordinate-in-the-obj-file-not-in-the-range # noqa: B950
# pyre-fixme[16]: `ByteTensor` has no attribute `any`.
if (faces_verts_uvs > 1).any() or (faces_verts_uvs < 0).any():
msg = "Texture UV coordinates outside the range [0, 1]. \
The integer part will be ignored to form a repeating pattern."
warnings.warn(msg)
# pyre-fixme[9]: faces_verts_uvs has type `Tensor`; used as `int`.
# pyre-fixme[58]: `%` is not supported for operand types `Tensor` and `int`.
faces_verts_uvs = faces_verts_uvs % 1
elif texture_wrap == "clamp":
# Clamp uv coordinates to the [0, 1] range.
faces_verts_uvs = faces_verts_uvs.clamp(0.0, 1.0)

# Iterate through the material properties - not
# all materials have texture images so this has to be
# done separately to the texture interpolation.
for material_name, props in material_properties.items():
# Bool to indicate which faces use this texture map.
faces_material_ind = torch.from_numpy(face_material_names == material_name).to(
faces_verts_uvs.device
)
if faces_material_ind.sum() > 0:
# For these faces, update the base color to the
# diffuse material color.
if "diffuse_color" not in props:
continue
atlas[faces_material_ind, ...] = props["diffuse_color"][None, :]

# Iterate through the materials used in this mesh. Update the
# texture atlas for the faces which use this material.
# Faces without texture are white.
Expand Down
23 changes: 10 additions & 13 deletions pytorch3d/io/obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,19 +533,16 @@ def _load_obj(
face_material_names = np.array(material_names)[idx] # (F,)
face_material_names[idx == -1] = ""

if len(verts_uvs) > 0:
# Get the uv coords for each vert in each face
faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2)

# Construct the atlas.
texture_atlas = make_mesh_texture_atlas(
material_colors,
texture_images,
face_material_names,
faces_verts_uvs,
texture_atlas_size,
texture_wrap,
)
# Construct the atlas.
texture_atlas = make_mesh_texture_atlas(
material_colors,
texture_images,
face_material_names,
faces_textures_idx,
verts_uvs,
texture_atlas_size,
texture_wrap,
)

faces = _Faces(
verts_idx=faces_verts_idx,
Expand Down
7 changes: 7 additions & 0 deletions tests/data/obj_mtl_no_image/model.mtl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Material Count: 1

newmtl material_1
Ns 96.078431
Ka 0.000000 0.000000 0.000000
Kd 0.500000 0.000000 0.000000
Ks 0.500000 0.500000 0.500000
10 changes: 10 additions & 0 deletions tests/data/obj_mtl_no_image/model.obj
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

mtllib model.mtl

v 0.1 0.2 0.3
v 0.2 0.3 0.4
v 0.3 0.4 0.5
v 0.4 0.5 0.6
usemtl material_1
f 1 2 3
f 1 2 4
29 changes: 29 additions & 0 deletions tests/test_obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,35 @@ def test_load_mtl_fail(self):
self.assertTrue(aux.normals is None)
self.assertTrue(aux.verts_uvs is None)

def test_load_obj_mlt_no_image(self):
DATA_DIR = Path(__file__).resolve().parent / "data"
obj_filename = "obj_mtl_no_image/model.obj"
filename = os.path.join(DATA_DIR, obj_filename)
R = 8
verts, faces, aux = load_obj(
filename,
load_textures=True,
create_texture_atlas=True,
texture_atlas_size=R,
texture_wrap=None,
)

expected_verts = torch.tensor(
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32,
)
expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64)
self.assertTrue(torch.allclose(verts, expected_verts))
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))

# Check that the material diffuse color has been assigned to all the
# values in the texture atlas.
expected_atlas = torch.tensor([0.5, 0.0, 0.0], dtype=torch.float32)
expected_atlas = expected_atlas[None, None, None, :].expand(2, R, R, -1)
self.assertTrue(torch.allclose(aux.texture_atlas, expected_atlas))
self.assertEquals(len(aux.material_colors.keys()), 1)
self.assertEquals(list(aux.material_colors.keys()), ["material_1"])

def test_load_obj_missing_texture(self):
DATA_DIR = Path(__file__).resolve().parent / "data"
obj_filename = "missing_files_obj/model.obj"
Expand Down

0 comments on commit f5383a7

Please sign in to comment.