Skip to content

Commit

Permalink
subsample pointclouds
Browse files Browse the repository at this point in the history
Summary: New function to randomly subsample Pointclouds to a maximum size.

Reviewed By: nikhilaravi

Differential Revision: D30936533

fbshipit-source-id: 789eb5004b6a233034ec1c500f20f2d507a303ff
  • Loading branch information
bottler authored and facebook-github-bot committed Oct 2, 2021
1 parent ee2b2fe commit 4281df1
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 21 deletions.
52 changes: 52 additions & 0 deletions pytorch3d/structures/pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import zip_longest
from typing import Sequence, Union

import numpy as np
import torch

from ..common.types import Device, make_device
Expand Down Expand Up @@ -841,6 +845,54 @@ def offset(self, offsets_packed):
new_clouds = self.clone()
return new_clouds.offset_(offsets_packed)

def subsample(self, max_points: Union[int, Sequence[int]]) -> "Pointclouds":
"""
Subsample each cloud so that it has at most max_points points.
Args:
max_points: maximum number of points in each cloud.
Returns:
new Pointclouds object, or self if nothing to be done.
"""
if isinstance(max_points, int):
max_points = [max_points] * len(self)
elif len(max_points) != len(self):
raise ValueError("wrong number of max_points supplied")
if all(
int(n_points) <= int(max_)
for n_points, max_ in zip(self.num_points_per_cloud(), max_points)
):
return self

points_list = []
features_list = []
normals_list = []
for max_, n_points, points, features, normals in zip_longest(
map(int, max_points),
map(int, self.num_points_per_cloud()),
self.points_list(),
self.features_list() or (),
self.normals_list() or (),
):
if n_points > max_:
keep_np = np.random.choice(n_points, max_, replace=False)
keep = torch.tensor(keep_np).to(points.device)
points = points[keep]
if features is not None:
features = features[keep]
if normals is not None:
normals = normals[keep]
points_list.append(points)
features_list.append(features)
normals_list.append(normals)

return Pointclouds(
points=points_list,
normals=self.normals_list() and normals_list,
features=self.features_list() and features_list,
)

def scale_(self, scale):
"""
Multiply the coordinates of this object by a scalar value.
Expand Down
22 changes: 1 addition & 21 deletions pytorch3d/vis/plotly_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import warnings
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
Expand Down Expand Up @@ -644,31 +643,12 @@ def _add_pointcloud_trace(
max_points_per_pointcloud: the number of points to render, which are randomly sampled.
marker_size: the size of the rendered points
"""
pointclouds = pointclouds.detach().cpu()
pointclouds = pointclouds.detach().cpu().subsample(max_points_per_pointcloud)
verts = pointclouds.points_packed()
features = pointclouds.features_packed()

indices = None
if pointclouds.num_points_per_cloud().max() > max_points_per_pointcloud:
start_index = 0
index_list = []
for num_points in pointclouds.num_points_per_cloud():
if num_points > max_points_per_pointcloud:
indices_cloud = np.random.choice(
num_points, max_points_per_pointcloud, replace=False
)
index_list.append(start_index + indices_cloud)
else:
index_list.append(start_index + np.arange(num_points))
start_index += num_points
indices = np.concatenate(index_list)
verts = verts[indices]

color = None
if features is not None:
if indices is not None:
# Only select features if we selected vertices above
features = features[indices]
if features.shape[1] == 4: # rgba
template = "rgb(%d, %d, %d, %f)"
rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int()
Expand Down
39 changes: 39 additions & 0 deletions tests/test_pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,45 @@ def test_estimate_normals(self):
clouds.normals_packed(), torch.cat(normals_est_list, dim=0)
)

def test_subsample(self):
lengths = [4, 5, 13, 3]
points = [torch.rand(length, 3) for length in lengths]
features = [torch.rand(length, 5) for length in lengths]
normals = [torch.rand(length, 3) for length in lengths]

pcl1 = Pointclouds(points=points).cuda()
self.assertIs(pcl1, pcl1.subsample(13))
self.assertIs(pcl1, pcl1.subsample([6, 13, 13, 13]))

lengths_max_4 = torch.tensor([4, 4, 4, 3]).cuda()
for with_normals, with_features in itertools.product([True, False], repeat=2):
with self.subTest(f"{with_normals} {with_features}"):
pcl = Pointclouds(
points=points,
normals=normals if with_normals else None,
features=features if with_features else None,
)
pcl_copy = pcl.subsample(max_points=4)
for length, points_ in zip(lengths_max_4, pcl_copy.points_list()):
self.assertEqual(points_.shape, (length, 3))
if with_normals:
for length, normals_ in zip(lengths_max_4, pcl_copy.normals_list()):
self.assertEqual(normals_.shape, (length, 3))
else:
self.assertIsNone(pcl_copy.normals_list())
if with_features:
for length, features_ in zip(
lengths_max_4, pcl_copy.features_list()
):
self.assertEqual(features_.shape, (length, 5))
else:
self.assertIsNone(pcl_copy.features_list())

pcl2 = Pointclouds(points=points)
pcl_copy2 = pcl2.subsample(lengths_max_4)
for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
self.assertEqual(points_.shape, (length, 3))

@staticmethod
def compute_packed_with_init(
num_clouds: int = 10, max_p: int = 100, features: int = 300
Expand Down

0 comments on commit 4281df1

Please sign in to comment.