Skip to content

Commit

Permalink
Farthest point sampling python naive
Browse files Browse the repository at this point in the history
Summary:
This is a naive python implementation of the iterative farthest point sampling algorithm along with associated simple tests. The C++/CUDA implementations will follow in subsequent diffs.

The algorithm subsamples a pointcloud so that the selected points cover the extent of the pointcloud more uniformly than uniform random sampling would.

The function has not been added to `__init__.py`. I will add this after the full C++/CUDA implementations.

Reviewed By: jcjohnson

Differential Revision: D30285716

fbshipit-source-id: 33f4181041fc652776406bcfd67800a6f0c3dd58
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Sep 15, 2021
1 parent a0d76a7 commit 3b7d78c
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 12 deletions.
14 changes: 2 additions & 12 deletions pytorch3d/ops/ball_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.autograd.function import once_differentiable

from .knn import _KNN
from .utils import masked_gather


class _ball_query(Function):
Expand Down Expand Up @@ -123,7 +124,6 @@ def ball_query(
p2 = p2.contiguous()
P1 = p1.shape[1]
P2 = p2.shape[1]
D = p2.shape[2]
N = p1.shape[0]

if lengths1 is None:
Expand All @@ -135,16 +135,6 @@ def ball_query(
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)

# Gather the neighbors if needed
points_nn = None
if return_nn:
idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D)
idx_mask = idx_expanded.eq(-1)
idx_new = idx_expanded.clone()
# Replace -1 values with 0 for gather
idx_new[idx_mask] = 0
# Gather points from p2
points_nn = p2[:, :, None].expand(-1, -1, K, -1).gather(1, idx_new)
# Replace padded values
points_nn[idx_mask] = 0.0
points_nn = masked_gather(p2, idx) if return_nn else None

return _KNN(dists=dists, idx=idx, knn=points_nn)
124 changes: 124 additions & 0 deletions pytorch3d/ops/sample_farthest_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from random import randint
from typing import Optional, Tuple, Union, List

import torch

from .utils import masked_gather


def sample_farthest_points_naive(
    points: torch.Tensor,
    lengths: Optional[torch.Tensor] = None,
    K: Union[int, List, torch.Tensor] = 50,
    random_start_point: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Iterative farthest point sampling algorithm [1] to subsample a set of
    K points from a given pointcloud. At each iteration, a point is selected
    which has the largest nearest neighbor distance to any of the
    already selected points.

    Farthest point sampling provides more uniform coverage of the input
    point cloud compared to uniform random sampling.

    [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
        on Point Sets in a Metric Space", NeurIPS 2017.

    Args:
        points: (N, P, D) array containing the batch of pointclouds
        lengths: (N,) number of points in each pointcloud (to support heterogeneous
            batches of pointclouds). Every entry must lie in [0, P]. Defaults to
            P for every pointcloud.
        K: samples you want in each sampled point cloud (this is typically << P). If
            K is an int then the same number of samples are selected for each
            pointcloud in the batch. If K is a list or tensor it should have
            length (N,) giving the number of samples to select for each element
            in the batch.
        random_start_point: bool, if True, a random point is selected as the starting
            point for iterative sampling. Otherwise the point at index 0 is used.

    Returns:
        selected_points: (N, K, D), array of selected values from points. If the input
            K is a list or tensor, then the shape will be (N, max(K), D), and
            padded with 0.0 for batch elements where k_i < max(K).
        selected_indices: (N, K) array of selected indices. If the input
            K is a list or tensor, then the shape will be (N, max(K)), and
            padded with -1 for batch elements where k_i < max(K).

        A pointcloud with fewer than k_i points only yields lengths[i] samples;
        the remaining slots are padded in the same way.

    Raises:
        ValueError: if lengths or K do not match the batch dimension of points,
            or if an entry of lengths lies outside [0, P].
    """
    N, P, D = points.shape
    device = points.device

    # Validate inputs
    if lengths is None:
        lengths = torch.full((N,), P, dtype=torch.int64, device=device)

    if lengths.shape[0] != N:
        raise ValueError("points and lengths must have same batch dimension.")

    # A length outside [0, P] would otherwise surface as an opaque shape
    # mismatch inside the sampling loop.
    if lengths.numel() > 0 and (lengths.max() > P or lengths.min() < 0):
        raise ValueError("lengths must be in the range [0, P].")

    # TODO: support providing K as a ratio of the total number of points instead of as an int
    if isinstance(K, int):
        K = torch.full((N,), K, dtype=torch.int64, device=device)
    elif isinstance(K, list):
        K = torch.tensor(K, dtype=torch.int64, device=device)

    if K.shape[0] != N:
        raise ValueError("K and points must have the same batch dimension")

    # Width of the padded output
    max_K = int(torch.max(K))

    # List of selected indices from each batch element
    all_sampled_indices = []

    for n in range(N):
        # Read the per-cloud sizes as Python ints once, rather than passing
        # tensor scalars to range/randint/min on every use.
        length_n = int(lengths[n])

        # If the pointcloud has fewer than K points then only iterate over the min
        k_n = min(length_n, int(K[n]))

        # Initialize an array for the sampled indices, shape: (max_K,).
        # Slots that are never filled keep the padding value -1.
        sample_idx_batch = torch.full(
            (max_K,), fill_value=-1, dtype=torch.int64, device=device
        )

        # Nothing can be selected from an empty cloud or when no samples are
        # requested; in that case the row stays fully padded.
        if k_n > 0:
            cloud = points[n, :length_n, :]  # (length_n, D)

            # Initialize closest distances to inf, shape: (length_n,)
            # This will be updated at each iteration to track the closest
            # distance of every point to any of the selected points
            # pyre-fixme[16]: `torch.Tensor` has no attribute new_full.
            closest_dists = points.new_full(
                (length_n,), float("inf"), dtype=torch.float32
            )

            # Select the starting point (random if requested, else index 0)
            selected_idx = randint(0, length_n - 1) if random_start_point else 0
            sample_idx_batch[0] = selected_idx

            # Iteratively select points for a maximum of k_n
            for i in range(1, k_n):
                # Squared distance between the last selected point and all
                # points. An already selected point has distance 0.0 to itself
                # so it will not be picked again as the max (unless every
                # remaining point coincides with a selected one).
                dist = cloud[selected_idx, :] - cloud
                dist_to_last_selected = (dist ** 2).sum(-1)  # (length_n,)

                # If closer than currently saved distance to one of the selected
                # points, then update closest_dists
                closest_dists = torch.min(
                    dist_to_last_selected, closest_dists
                )  # (length_n,)

                # The aim is to pick the point that has the largest
                # nearest neighbour distance to any of the already selected points
                selected_idx = torch.argmax(closest_dists)
                sample_idx_batch[i] = selected_idx

        # Add the list of points for this batch to the final list
        all_sampled_indices.append(sample_idx_batch)

    all_sampled_indices = torch.stack(all_sampled_indices, dim=0)

    # Gather the points; padded (-1) indices yield rows of 0.0
    all_sampled_points = masked_gather(points, all_sampled_indices)

    # Return (N, max_K, D) subsampled points and (N, max_K) indices
    return all_sampled_points, all_sampled_indices
48 changes: 48 additions & 0 deletions pytorch3d/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,54 @@
from pytorch3d.structures import Pointclouds


def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Wrapper around torch.gather that tolerates padded indices.

    Entries of idx equal to -1 mark padding. They are temporarily replaced
    with 0 so that the gather is valid, and the corresponding output rows
    are then overwritten with 0.0.

    Args:
        points: (N, P, D) tensor of points
        idx: (N, K) or (N, P', K) long tensor of indices into the P dimension
            of points, where some indices are -1 to indicate padding

    Returns:
        selected_points: tensor of the points at the given indices, of shape
            (N, K, D) when idx is (N, K) and (N, P', K, D) when idx is
            (N, P', K). Padded positions are filled with 0.0.

    Raises:
        ValueError: if idx and points differ in batch size, or idx is neither
            2D nor 3D.
    """
    if len(idx) != len(points):
        raise ValueError("points and idx must have the same batch dimension")

    # Unpacking three values also requires points to be a 3D tensor.
    _, _, feat_dim = points.shape

    if idx.ndim == 2:
        # Farthest point sampling: one list of K indices per batch element.
        index = idx.unsqueeze(-1).expand(-1, -1, feat_dim)
        source = points
    elif idx.ndim == 3:
        # KNN / ball query: K indices for each of P' query points. P' may
        # differ from P since the queries can come from another pointcloud.
        num_neighbors = idx.shape[2]
        index = idx.unsqueeze(-1).expand(-1, -1, -1, feat_dim)
        # Repeat the points along a new neighbor axis so that points and
        # indices have the same number of dimensions for gather.
        source = points.unsqueeze(2).expand(-1, -1, num_neighbors, -1)
    else:
        raise ValueError("idx format is not supported %s" % repr(idx.shape))

    padding = index == -1
    # Out-of-place fill, so the caller's idx tensor is left untouched.
    safe_index = index.masked_fill(padding, 0)
    gathered = torch.gather(source, 1, safe_index)
    # Padded slots picked up point 0 above; blank them out.
    gathered[padding] = 0.0
    return gathered


def wmean(
x: torch.Tensor,
weight: Optional[torch.Tensor] = None,
Expand Down
10 changes: 10 additions & 0 deletions tests/test_ops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,13 @@ def test_wmean(self):
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)

def test_masked_gather_errors(self):
    """masked_gather raises ValueError for unsupported idx rank and batch mismatch."""
    # A 4D index tensor is neither of the two supported layouts
    # ((N, K) or (N, P, K)).
    bad_rank_idx = torch.randint(0, 10, size=(5, 10, 4, 2))
    with self.assertRaisesRegex(ValueError, "format is not supported"):
        oputil.masked_gather(torch.randn(size=(5, 10, 3)), bad_rank_idx)

    # Batch size 2 for the points vs. 5 for the indices.
    with self.assertRaisesRegex(ValueError, "same batch dimension"):
        oputil.masked_gather(torch.randn(size=(2, 10, 3)), bad_rank_idx)
111 changes: 111 additions & 0 deletions tests/test_sample_farthest_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops.sample_farthest_points import sample_farthest_points_naive
from pytorch3d.ops.utils import masked_gather


class TestFPS(TestCaseMixin, unittest.TestCase):
    """Tests for the naive python farthest point sampling implementation."""

    def test_simple(self):
        """Check sampled indices and points on a small hand-written 2D batch."""
        device = get_random_cuda_device()
        # fmt: off
        points = torch.tensor(
            [
                [
                    [-1.0, -1.0],  # noqa: E241, E201
                    [-1.3,  1.1],  # noqa: E241, E201
                    [ 0.2, -1.1],  # noqa: E241, E201
                    [ 0.0,  0.0],  # noqa: E241, E201
                    [ 1.3,  1.3],  # noqa: E241, E201
                    [ 1.0,  0.5],  # noqa: E241, E201
                    [-1.3,  0.2],  # noqa: E241, E201
                    [ 1.5, -0.5],  # noqa: E241, E201
                ],
                [
                    [-2.2, -2.4],  # noqa: E241, E201
                    [-2.1,  2.0],  # noqa: E241, E201
                    [ 2.2,  2.1],  # noqa: E241, E201
                    [ 2.1, -2.4],  # noqa: E241, E201
                    [ 0.4, -1.0],  # noqa: E241, E201
                    [ 0.3,  0.3],  # noqa: E241, E201
                    [ 1.2,  0.5],  # noqa: E241, E201
                    [ 4.5,  4.5],  # noqa: E241, E201
                ],
            ],
            dtype=torch.float32,
            device=device,
        )
        # fmt: on
        expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
        out_points, out_inds = sample_farthest_points_naive(points, K=2)
        self.assertClose(out_inds, expected_inds)

        # Gather the points
        expected_inds = expected_inds[..., None].expand(-1, -1, points.shape[-1])
        self.assertClose(out_points, points.gather(dim=1, index=expected_inds))

        # Different number of points sampled for each pointcloud in the batch
        expected_inds = torch.tensor(
            [[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
        )
        out_points, out_inds = sample_farthest_points_naive(points, K=[3, 2])
        self.assertClose(out_inds, expected_inds)

        # Gather the points
        expected_points = masked_gather(points, expected_inds)
        self.assertClose(out_points, expected_points)

    def test_random_heterogeneous(self):
        """Check index validity for full and heterogeneous random batches."""
        device = get_random_cuda_device()
        N, P, D, K = 5, 40, 5, 8
        points = torch.randn((N, P, D), device=device)
        out_points, out_idxs = sample_farthest_points_naive(points, K=K)
        self.assertTrue(out_idxs.min() >= 0)
        for n in range(N):
            self.assertEqual(out_idxs[n].ne(-1).sum(), K)

        lengths = torch.randint(low=1, high=P, size=(N,), device=device)
        out_points, out_idxs = sample_farthest_points_naive(points, lengths, K=50)

        for n in range(N):
            # Check that for heterogeneous batches, the number of
            # selected points is at most the length
            self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n])
            # Indices are 0-based so the largest valid index is lengths[n] - 1
            self.assertTrue(out_idxs[n].max() < lengths[n])

            # Check there are no duplicate indices
            val_mask = out_idxs[n].ne(-1)
            vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
            self.assertTrue(counts.le(1).all())

    def test_errors(self):
        """Check that mismatched batch dimensions raise ValueError."""
        device = get_random_cuda_device()
        N, P, D, K = 5, 40, 5, 8
        points = torch.randn((N, P, D), device=device)
        wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)

        # K has different batch dimension to points
        with self.assertRaisesRegex(ValueError, "K and points must have"):
            sample_farthest_points_naive(points, K=wrong_batch_dim)

        # lengths has different batch dimension to points
        with self.assertRaisesRegex(ValueError, "points and lengths must have"):
            sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K)

    def test_random_start(self):
        """Check that random_start_point=True does not always start at index 0."""
        device = get_random_cuda_device()
        N, P, D, K = 5, 40, 5, 8
        points = torch.randn((N, P, D), device=device)
        out_points, out_idxs = sample_farthest_points_naive(
            points, K=K, random_start_point=True
        )
        # Check the first index is not 0 for all batch elements
        # when random_start_point = True
        self.assertTrue(out_idxs[:, 0].sum() > 0)

0 comments on commit 3b7d78c

Please sign in to comment.