Skip to content

Commit

Permalink
Single function to load meshes from OBJs. join_meshes.
Browse files Browse the repository at this point in the history
Summary:
Create the textures and the Meshes object from OBJ files in a single call.

There is functionality in OBJ files (like normals) which is ignored by this function.

Reviewed By: gkioxari

Differential Revision: D19691699

fbshipit-source-id: e26442ed80ff231b65b17d6c54c9d41e22b4e4a3
  • Loading branch information
bottler authored and facebook-github-bot committed Feb 13, 2020
1 parent 23bb279 commit 8fe65d5
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 44 deletions.
3 changes: 3 additions & 0 deletions docs/notes/meshes_io.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_image)
# Initialise the mesh with textures
meshes = Meshes(verts=[verts], faces=[faces.verts_idx], textures=tex)
```

The `load_objs_as_meshes` function provides this procedure.

## PLY

Ply files are flexible in the way they store additional information, pytorch3d
Expand Down
23 changes: 3 additions & 20 deletions docs/tutorials/render_textured_meshes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
"from skimage.io import imread\n",
"\n",
"# Util function for loading meshes\n",
"from pytorch3d.io import load_obj\n",
"from pytorch3d.io import load_objs_as_meshes\n",
"\n",
"# Data structures and functions for rendering\n",
"from pytorch3d.structures import Meshes, Textures\n",
Expand Down Expand Up @@ -232,25 +232,8 @@
"obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n",
"\n",
"# Load obj file\n",
"verts, faces, aux = load_obj(obj_filename)\n",
"faces_idx = faces.verts_idx.to(device)\n",
"verts = verts.to(device)\n",
"\n",
"# Get textures from the outputs of the load_obj function\n",
"# the `aux` variable contains the texture maps and vertex uv coordinates. \n",
"# Refer to the `obj_io.load_obj` function for full API reference. \n",
"# Here we only have one texture map for the whole mesh. \n",
"verts_uvs = aux.verts_uvs[None, ...].to(device) # (N, V, 2)\n",
"faces_uvs = faces.textures_idx[None, ...].to(device) # (N, F, 3)\n",
"tex_maps = aux.texture_images\n",
"texture_image = list(tex_maps.values())[0]\n",
"texture_image = texture_image[None, ...].to(device) # (N, H, W, 3)\n",
"\n",
"# Create a textures object\n",
"tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_image)\n",
"\n",
"# Create a meshes object with textures\n",
"mesh = Meshes(verts=[verts], faces=[faces_idx], textures=tex)"
"mesh = load_objs_as_meshes([obj_filename], device=device)\n",
"texture_image=mesh.textures.maps_padded()"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


from .obj_io import load_obj, save_obj
from .obj_io import load_obj, load_objs_as_meshes, save_obj
from .ply_io import load_ply, save_ply

__all__ = [k for k in globals().keys() if not k.startswith("_")]
42 changes: 41 additions & 1 deletion pytorch3d/io/obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from fvcore.common.file_io import PathManager
from PIL import Image

from pytorch3d.structures import Meshes, Textures, join_meshes


def _read_image(file_name: str, format=None):
"""
Expand Down Expand Up @@ -90,7 +92,7 @@ def _open_file(f):

def load_obj(f_obj, load_textures=True):
"""
Load a mesh and textures from a .obj and .mtl file.
Load a mesh from a .obj file and optionally textures from a .mtl file.
Currently this handles verts, faces, vertex texture uv coordinates, normals,
texture images and material reflectivity values.
Expand Down Expand Up @@ -208,6 +210,44 @@ def load_obj(f_obj, load_textures=True):
f_obj.close()


def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
"""
Load meshes from a list of .obj files using the load_obj function, and
return them as a Meshes object. This only works for meshes which have a
single texture image for the whole mesh. See the load_obj function for more
details. material_colors and normals are not stored.
Args:
f: A list of file-like objects (with methods read, readline, tell,
and seek), pathlib paths or strings containing file names.
device: Desired device of returned Meshes. Default:
uses the current device for the default tensor type.
load_textures: Boolean indicating whether material files are loaded
Returns:
New Meshes object.
"""
mesh_list = []
for f_obj in files:
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
verts = verts.to(device)
tex = None
tex_maps = aux.texture_images
if tex_maps is not None and len(tex_maps) > 0:
verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2)
faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3)
image = list(tex_maps.values())[0].to(device)[None]
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)

mesh = Meshes(
verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex
)
mesh_list.append(mesh)
if len(mesh_list) == 1:
return mesh_list[0]
return join_meshes(mesh_list)


def _parse_face(
line,
material_idx,
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/structures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from .meshes import Meshes
from .meshes import Meshes, join_meshes
from .textures import Textures
from .utils import (
list_to_packed,
Expand Down
75 changes: 75 additions & 0 deletions pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import List
import torch

from pytorch3d import _C
Expand Down Expand Up @@ -1365,3 +1366,77 @@ def extend(self, N: int):
if self.textures is not None:
tex = self.textures.extend(N)
return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)


def join_meshes(meshes: List[Meshes], include_textures: bool = True):
"""
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
must all be on the same device. If include_textures is true, they must all
be compatible, either all or none having textures, and all the Textures
objects having the same members. If include_textures is False, textures are
ignored.
Args:
meshes: list of meshes.
include_textures: (bool) whether to try to join the textures.
Returns:
new Meshes object containing all the meshes from all the inputs.
"""
if isinstance(meshes, Meshes):
# Meshes objects can be iterated and produce single Meshes. We avoid
# letting join_meshes(mesh1, mesh2) silently do the wrong thing.
raise ValueError("Wrong first argument to join_meshes.")
verts = [v for mesh in meshes for v in mesh.verts_list()]
faces = [f for mesh in meshes for f in mesh.faces_list()]
if len(meshes) == 0 or not include_textures:
return Meshes(verts=verts, faces=faces)

if meshes[0].textures is None:
if any(mesh.textures is not None for mesh in meshes):
raise ValueError("Inconsistent textures in join_meshes.")
return Meshes(verts=verts, faces=faces)

if any(mesh.textures is None for mesh in meshes):
raise ValueError("Inconsistent textures in join_meshes.")

# Now we know there are multiple meshes and they have textures to merge.
first = meshes[0].textures
kwargs = {}
if first.maps_padded() is not None:
if any(mesh.textures.maps_padded() is None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes.")
maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
kwargs["maps"] = maps
elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes.")

if first.verts_uvs_padded() is not None:
if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes.")
uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
V = max(uv.shape[0] for uv in uvs)
kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes.")

if first.faces_uvs_padded() is not None:
if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes.")
uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
F = max(uv.shape[0] for uv in uvs)
kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes.")

if first.verts_rgb_padded() is not None:
if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes.")
rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
V = max(i.shape[0] for i in rgb)
kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes.")

tex = Textures(**kwargs)
return Meshes(verts=verts, faces=faces, textures=tex)
89 changes: 87 additions & 2 deletions tests/test_obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
from pathlib import Path
import torch

from pytorch3d.io import load_obj, save_obj
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
from pytorch3d.structures import Meshes, Textures, join_meshes

from common_testing import TestCaseMixin

class TestMeshObjIO(unittest.TestCase):

class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
def test_load_obj_simple(self):
obj_file = "\n".join(
[
Expand Down Expand Up @@ -517,6 +520,88 @@ def test_load_obj_missing_mtl_noload(self):
self.assertTrue(aux.material_colors is None)
self.assertTrue(aux.texture_images is None)

def test_join_meshes(self):
"""
Test that join_meshes and load_objs_as_meshes are consistent with single
meshes.
"""

def check_triple(mesh, mesh3):
"""
Verify that mesh3 is three copies of mesh.
"""

def check_item(x, y):
self.assertEqual(x is None, y is None)
if x is not None:
self.assertClose(torch.cat([x, x, x]), y)

check_item(mesh.verts_padded(), mesh3.verts_padded())
check_item(mesh.faces_padded(), mesh3.faces_padded())
if mesh.textures is not None:
check_item(
mesh.textures.maps_padded(), mesh3.textures.maps_padded()
)
check_item(
mesh.textures.faces_uvs_padded(),
mesh3.textures.faces_uvs_padded(),
)
check_item(
mesh.textures.verts_uvs_padded(),
mesh3.textures.verts_uvs_padded(),
)
check_item(
mesh.textures.verts_rgb_padded(),
mesh3.textures.verts_rgb_padded(),
)

DATA_DIR = (
Path(__file__).resolve().parent.parent / "docs/tutorials/data"
)
obj_filename = DATA_DIR / "cow_mesh/cow.obj"

mesh = load_objs_as_meshes([obj_filename])
mesh3 = load_objs_as_meshes([obj_filename, obj_filename, obj_filename])
check_triple(mesh, mesh3)
self.assertTupleEqual(
mesh.textures.maps_padded().shape, (1, 1024, 1024, 3)
)

mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False)
mesh3_notex = load_objs_as_meshes(
[obj_filename, obj_filename, obj_filename], load_textures=False
)
check_triple(mesh_notex, mesh3_notex)
self.assertIsNone(mesh_notex.textures)

verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex)
mesh_rgb3 = join_meshes([mesh_rgb, mesh_rgb, mesh_rgb])
check_triple(mesh_rgb, mesh_rgb3)

teapot_obj = DATA_DIR / "teapot.obj"
mesh_teapot = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)
mix_mesh = load_objs_as_meshes(
[obj_filename, teapot_obj], load_textures=False
)
self.assertEqual(len(mix_mesh), 2)
self.assertClose(mix_mesh.verts_list()[0], mesh.verts_list()[0])
self.assertClose(mix_mesh.faces_list()[0], mesh.faces_list()[0])
self.assertClose(mix_mesh.verts_list()[1], teapot_verts)
self.assertClose(mix_mesh.faces_list()[1], teapot_faces)

cow3_tea = join_meshes([mesh3, mesh_teapot], include_textures=False)
self.assertEqual(len(cow3_tea), 4)
check_triple(mesh_notex, cow3_tea[:3])
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])

@staticmethod
def save_obj_with_init(V: int, F: int):
verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)
Expand Down
26 changes: 7 additions & 19 deletions tests/test_rendering_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from PIL import Image

from pytorch3d.io import load_obj
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer.cameras import (
OpenGLPerspectiveCameras,
look_at_view_transform,
Expand Down Expand Up @@ -274,21 +274,7 @@ def test_texture_map(self):
obj_filename = DATA_DIR / "cow_mesh/cow.obj"

# Load mesh + texture
verts, faces, aux = load_obj(obj_filename)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)
texture_uvs = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images

# tex_maps is a dictionary of material names as keys and texture images
# as values. Only need the images for this example.
textures = Textures(
maps=list(tex_maps.values()),
faces_uvs=faces.textures_idx.to(torch.int64).to(device)[None, :],
verts_uvs=texture_uvs.to(torch.float32).to(device)[None, :],
)
mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures)
mesh = load_objs_as_meshes([obj_filename], device=device)

# Init rasterizer settings
R, T = look_at_view_transform(2.7, 10, 20)
Expand Down Expand Up @@ -333,9 +319,11 @@ def test_texture_map(self):
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

# Check grad exists
verts = verts.clone()
[verts] = mesh.verts_list()
verts.requires_grad = True
mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures)
images = renderer(mesh)
mesh2 = Meshes(
verts=[verts], faces=mesh.faces_list(), textures=mesh.textures
)
images = renderer(mesh2)
images[0, ...].sum().backward()
self.assertIsNotNone(verts.grad)

0 comments on commit 8fe65d5

Please sign in to comment.