-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Farthest point sampling python naive
Summary: This is a naive python implementation of the iterative farthest point sampling algorithm along with associated simple tests. The C++/CUDA implementations will follow in subsequent diffs. The algorithm is used to subsample a pointcloud with better coverage of the space of the pointcloud. The function has not been added to `__init__.py`. I will add this after the full C++/CUDA implementations. Reviewed By: jcjohnson Differential Revision: D30285716 fbshipit-source-id: 33f4181041fc652776406bcfd67800a6f0c3dd58
- Loading branch information
1 parent
a0d76a7
commit 3b7d78c
Showing
5 changed files
with
295 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from random import randint | ||
from typing import Optional, Tuple, Union, List | ||
|
||
import torch | ||
|
||
from .utils import masked_gather | ||
|
||
|
||
def sample_farthest_points_naive( | ||
points: torch.Tensor, | ||
lengths: Optional[torch.Tensor] = None, | ||
K: Union[int, List, torch.Tensor] = 50, | ||
random_start_point: bool = False, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
Iterative farthest point sampling algorithm [1] to subsample a set of | ||
K points from a given pointcloud. At each iteration, a point is selected | ||
which has the largest nearest neighbor distance to any of the | ||
already selected points. | ||
Farthest point sampling provides more uniform coverage of the input | ||
point cloud compared to uniform random sampling. | ||
[1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning | ||
on Point Sets in a Metric Space", NeurIPS 2017. | ||
Args: | ||
points: (N, P, D) array containing the batch of pointclouds | ||
lengths: (N,) number of points in each pointcloud (to support heterogeneous | ||
batches of pointclouds) | ||
K: samples you want in each sampled point cloud (this is typically << P). If | ||
K is an int then the same number of samples are selected for each | ||
pointcloud in the batch. If K is a tensor is should be length (N,) | ||
giving the number of samples to select for each element in the batch | ||
random_start_point: bool, if True, a random point is selected as the starting | ||
point for iterative sampling. | ||
Returns: | ||
selected_points: (N, K, D), array of selected values from points. If the input | ||
K is a tensor, then the shape will be (N, max(K), D), and padded with | ||
0.0 for batch elements where k_i < max(K). | ||
selected_indices: (N, K) array of selected indices. If the input | ||
K is a tensor, then the shape will be (N, max(K), D), and padded with | ||
-1 for batch elements where k_i < max(K). | ||
""" | ||
N, P, D = points.shape | ||
device = points.device | ||
|
||
# Validate inputs | ||
if lengths is None: | ||
lengths = torch.full((N,), P, dtype=torch.int64, device=device) | ||
|
||
if lengths.shape[0] != N: | ||
raise ValueError("points and lengths must have same batch dimension.") | ||
|
||
# TODO: support providing K as a ratio of the total number of points instead of as an int | ||
if isinstance(K, int): | ||
K = torch.full((N,), K, dtype=torch.int64, device=device) | ||
elif isinstance(K, list): | ||
K = torch.tensor(K, dtype=torch.int64, device=device) | ||
|
||
if K.shape[0] != N: | ||
raise ValueError("K and points must have the same batch dimension") | ||
|
||
# Find max value of K | ||
max_K = torch.max(K) | ||
|
||
# List of selected indices from each batch element | ||
all_sampled_indices = [] | ||
|
||
for n in range(N): | ||
# Initialize an array for the sampled indices, shape: (max_K,) | ||
sample_idx_batch = torch.full( | ||
(max_K,), fill_value=-1, dtype=torch.int64, device=device | ||
) | ||
|
||
# Initialize closest distances to inf, shape: (P,) | ||
# This will be updated at each iteration to track the closest distance of the | ||
# remaining points to any of the selected points | ||
# pyre-fixme[16]: `torch.Tensor` has no attribute new_full. | ||
closest_dists = points.new_full( | ||
(lengths[n],), float("inf"), dtype=torch.float32 | ||
) | ||
|
||
# Select a random point index and save it as the starting point | ||
selected_idx = randint(0, lengths[n] - 1) if random_start_point else 0 | ||
sample_idx_batch[0] = selected_idx | ||
|
||
# If the pointcloud has fewer than K points then only iterate over the min | ||
k_n = min(lengths[n], K[n]) | ||
|
||
# Iteratively select points for a maximum of k_n | ||
for i in range(1, k_n): | ||
# Find the distance between the last selected point | ||
# and all the other points. If a point has already been selected | ||
# it's distance will be 0.0 so it will not be selected again as the max. | ||
dist = points[n, selected_idx, :] - points[n, : lengths[n], :] | ||
dist_to_last_selected = (dist ** 2).sum(-1) # (P - i) | ||
|
||
# If closer than currently saved distance to one of the selected | ||
# points, then updated closest_dists | ||
closest_dists = torch.min(dist_to_last_selected, closest_dists) # (P - i) | ||
|
||
# The aim is to pick the point that has the largest | ||
# nearest neighbour distance to any of the already selected points | ||
selected_idx = torch.argmax(closest_dists) | ||
sample_idx_batch[i] = selected_idx | ||
|
||
# Add the list of points for this batch to the final list | ||
all_sampled_indices.append(sample_idx_batch) | ||
|
||
all_sampled_indices = torch.stack(all_sampled_indices, dim=0) | ||
|
||
# Gather the points | ||
all_sampled_points = masked_gather(points, all_sampled_indices) | ||
|
||
# Return (N, max_K, D) subsampled points and indices | ||
return all_sampled_points, all_sampled_indices |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from common_testing import TestCaseMixin, get_random_cuda_device | ||
from pytorch3d.ops.sample_farthest_points import sample_farthest_points_naive | ||
from pytorch3d.ops.utils import masked_gather | ||
|
||
|
||
class TestFPS(TestCaseMixin, unittest.TestCase): | ||
def test_simple(self): | ||
device = get_random_cuda_device() | ||
# fmt: off | ||
points = torch.tensor( | ||
[ | ||
[ | ||
[-1.0, -1.0], # noqa: E241, E201 | ||
[-1.3, 1.1], # noqa: E241, E201 | ||
[ 0.2, -1.1], # noqa: E241, E201 | ||
[ 0.0, 0.0], # noqa: E241, E201 | ||
[ 1.3, 1.3], # noqa: E241, E201 | ||
[ 1.0, 0.5], # noqa: E241, E201 | ||
[-1.3, 0.2], # noqa: E241, E201 | ||
[ 1.5, -0.5], # noqa: E241, E201 | ||
], | ||
[ | ||
[-2.2, -2.4], # noqa: E241, E201 | ||
[-2.1, 2.0], # noqa: E241, E201 | ||
[ 2.2, 2.1], # noqa: E241, E201 | ||
[ 2.1, -2.4], # noqa: E241, E201 | ||
[ 0.4, -1.0], # noqa: E241, E201 | ||
[ 0.3, 0.3], # noqa: E241, E201 | ||
[ 1.2, 0.5], # noqa: E241, E201 | ||
[ 4.5, 4.5], # noqa: E241, E201 | ||
], | ||
], | ||
dtype=torch.float32, | ||
device=device, | ||
) | ||
# fmt: on | ||
expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device) | ||
out_points, out_inds = sample_farthest_points_naive(points, K=2) | ||
self.assertClose(out_inds, expected_inds) | ||
|
||
# Gather the points | ||
expected_inds = expected_inds[..., None].expand(-1, -1, points.shape[-1]) | ||
self.assertClose(out_points, points.gather(dim=1, index=expected_inds)) | ||
|
||
# Different number of points sampled for each pointcloud in the batch | ||
expected_inds = torch.tensor( | ||
[[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device | ||
) | ||
out_points, out_inds = sample_farthest_points_naive(points, K=[3, 2]) | ||
self.assertClose(out_inds, expected_inds) | ||
|
||
# Gather the points | ||
expected_points = masked_gather(points, expected_inds) | ||
self.assertClose(out_points, expected_points) | ||
|
||
def test_random_heterogeneous(self): | ||
device = get_random_cuda_device() | ||
N, P, D, K = 5, 40, 5, 8 | ||
points = torch.randn((N, P, D), device=device) | ||
out_points, out_idxs = sample_farthest_points_naive(points, K=K) | ||
self.assertTrue(out_idxs.min() >= 0) | ||
for n in range(N): | ||
self.assertEqual(out_idxs[n].ne(-1).sum(), K) | ||
|
||
lengths = torch.randint(low=1, high=P, size=(N,), device=device) | ||
out_points, out_idxs = sample_farthest_points_naive(points, lengths, K=50) | ||
|
||
for n in range(N): | ||
# Check that for heterogeneous batches, the max number of | ||
# selected points is less than the length | ||
self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n]) | ||
self.assertTrue(out_idxs[n].max() <= lengths[n]) | ||
|
||
# Check there are no duplicate indices | ||
val_mask = out_idxs[n].ne(-1) | ||
vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True) | ||
self.assertTrue(counts.le(1).all()) | ||
|
||
def test_errors(self): | ||
device = get_random_cuda_device() | ||
N, P, D, K = 5, 40, 5, 8 | ||
points = torch.randn((N, P, D), device=device) | ||
wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device) | ||
|
||
# K has diferent batch dimension to points | ||
with self.assertRaisesRegex(ValueError, "K and points must have"): | ||
sample_farthest_points_naive(points, K=wrong_batch_dim) | ||
|
||
# lengths has diferent batch dimension to points | ||
with self.assertRaisesRegex(ValueError, "points and lengths must have"): | ||
sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K) | ||
|
||
def test_random_start(self): | ||
device = get_random_cuda_device() | ||
N, P, D, K = 5, 40, 5, 8 | ||
points = torch.randn((N, P, D), device=device) | ||
out_points, out_idxs = sample_farthest_points_naive( | ||
points, K=K, random_start_point=True | ||
) | ||
# Check the first index is not 0 for all batch elements | ||
# when random_start_point = True | ||
self.assertTrue(out_idxs[:, 0].sum() > 0) |