Algorithm to find nearest neighbors #9813
As I mentioned in the StackOverflow answer, the only approach I know of in JAX is to use a brute force search over all neighbors. For example: import numpy as np
from scipy.spatial import cKDTree
from functools import partial
from jax import jit
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)  # JAX PRNG key; unused by this brute-force example.
num_points, num_dimensions = 1000, 3
# Random points drawn uniformly from [0, 1) in each of the 3 dimensions.
X = np.random.rand(num_points, num_dimensions)
@partial(jit, static_argnums=1)
def nearest_neighbors_jax(X, k):
    """Return the indices of the ``k`` nearest points to every point of ``X``.

    Brute force: the full (N, N) matrix of squared Euclidean distances is built
    and every row is sorted, so time and memory grow as O(N^2). A point's
    distance to itself is 0, so it is normally returned in column 0, as with
    ``cKDTree(X).query(X, k)``.

    Args:
        X: (N, D) array of points.
        k: number of neighbours per point; static under ``jit``.

    Returns:
        (N, k) integer array; row i holds neighbour indices of X[i], ordered by
        increasing distance.
    """
    diffs = X[:, None, :] - X[None, :, :]  # (N, N, D) pairwise differences
    sq_dist = (diffs ** 2).sum(axis=-1)    # (N, N) squared distances
    order = jnp.argsort(sq_dist, axis=-1)
    return order[:, :k]
# Find the 5 nearest neighbours of every point with both methods, for comparison.
_, indices_scipy = cKDTree(X).query(X, k=5)
indices_jax = nearest_neighbors_jax(X, 5)
np.testing.assert_array_equal(indices_jax, indices_scipy) This is not going to be competitive with a tree-based approach as the number of points grows large (brute-force search is O(N^2), while a KD-tree is O(N log N) for N points). I'm not aware of any tree-based neighbor searches that are compatible with JAX; if you can find one, that would likely be the best approach. If not, then sticking with cKDTree is probably your best bet. I'm sorry my StackOverflow answer was not helpful; to be honest, I found it very confusing to ascertain exactly what you were asking. |
Is there any possibility that a KD-tree nearest-neighbor search will be implemented in JAX (it seems not to be possible without a built-in function)? This seems really important for a lot of models. I'm currently using JAX for graph neural networks and for modeling point clouds, and k-nearest neighbors comes up a lot, which means I've had to use PyTorch for performance-critical tasks. |
There is an approximate nearest neighbor implementation here but it seems to be designed for high dimensional vectors, not 3D... Also there is |
Though this is an oldish thread, I thought I'd post my approximate tree-based nearest neighbours implementation for others to use. It is for 2D points, but could easily be expanded to nD. If someone did that, they would be welcome to post it below. import dataclasses
import warnings
from functools import partial
from typing import NamedTuple, Tuple
import jax
import jax.numpy as jnp
import numpy as np
class GridTree(NamedTuple):
    """Grid index produced by `ApproximateTreeNN.build_tree`."""

    grid: jax.Array  # [num_cells, max_points_per_cell] indices into `points`; -1 marks an empty slot
    points: jax.Array  # [n_points, 2]
    extent: Tuple[jax.Array, jax.Array, jax.Array, jax.Array]  # (min_x, max_x, min_y, max_y)


@dataclasses.dataclass(eq=False)
class ApproximateTreeNN:
    """Approximate tree for nearest neighbor search on a 2D box.

    A grid of shape (n_grid, n_grid) is laid over the bounding box of the points, where

        n_grid = int(sqrt(n / average_points_per_cell)),  n = number of points,

    and each cell stores the indices of at most `max_points_per_cell` points. A query only
    searches the cell that contains the query point, so results are approximate.

    The memory usage goes as O(n * (2 + kappa / sqrt(average_points_per_cell))).
    The tree build time goes as O(n).
    The tree query time goes as O((average_points_per_cell + kappa * sqrt(average_points_per_cell)) + k log k).

    Accuracy generally increases with more points in the tree. Accuracy decreases with larger `k`
    queries due to cell edge effects.

    Attributes:
        average_points_per_cell: Average number of points per cell in the grid.
        kappa: How many sigmas above the expected number of points per cell to allow.
    """

    average_points_per_cell: int = 16
    kappa: float = 5.0

    def _point_to_cell(self, point: jax.Array, n_grid: int,
                       extent: Tuple[jax.Array, jax.Array, jax.Array, jax.Array]) -> Tuple[jax.Array, jax.Array]:
        """Map a 2D `point` to integer grid coordinates (cell_x, cell_y), clipped into the grid."""
        x_min, x_max, y_min, y_max = extent
        # A zero-width extent (all points share an x or a y value) would divide by zero and
        # produce NaN cell indices; fall back to a unit span in that case.
        x_span = jnp.where(x_max > x_min, x_max - x_min, 1.0)
        y_span = jnp.where(y_max > y_min, y_max - y_min, 1.0)
        cell_x = jnp.clip(jnp.floor((point[0] - x_min) / x_span * n_grid), 0, n_grid - 1).astype(int)
        cell_y = jnp.clip(jnp.floor((point[1] - y_min) / y_span * n_grid), 0, n_grid - 1).astype(int)
        return cell_x, cell_y

    def _grid_to_idx(self, cell_x: jax.Array, cell_y: jax.Array, n_grid: int) -> jax.Array:
        """Maps (cell_x, cell_y) to grid_idx."""
        return cell_y * n_grid + cell_x

    def _idx_to_grid(self, grid_idx: jax.Array, n_grid: int) -> Tuple[jax.Array, jax.Array]:
        """Maps grid_idx back to (cell_x, cell_y)."""
        cell_x = grid_idx % n_grid
        cell_y = grid_idx // n_grid
        return cell_x, cell_y

    def build_tree(self, points: jax.Array) -> GridTree:
        """Build the grid structure for `points` lying in a box [a, b] x [c, d].

        Args:
            points: Array of points with shape (n_points, 2).

        Returns:
            GridTree holding the cell -> point-index grid, the original points and the
            bounding box (min_x, max_x, min_y, max_y) of the points.

        Raises:
            ValueError: If `points` is empty or the derived cell capacity is below 1.
        """
        n_points = points.shape[0]
        if n_points == 0:
            raise ValueError("No points provided to build the tree.")
        n_grid = int(np.sqrt(n_points / self.average_points_per_cell))
        if n_grid < 1:
            warnings.warn("Number of points is too small to meet desired average points per cell.")
            n_grid = 1
        num_cells = n_grid * n_grid
        # Cell capacity: expected points per cell plus `kappa` times its square root.
        max_points_per_cell = int(
            n_points / num_cells + self.kappa * np.sqrt(n_points / num_cells)
        )
        if max_points_per_cell < 1:
            raise ValueError("max_points_per_cell must be at least 1.")

        points_min = jnp.min(points, axis=0)
        points_max = jnp.max(points, axis=0)
        extent = (points_min[0], points_max[0], points_min[1], points_max[1])

        grid = -jnp.ones((num_cells, max_points_per_cell), dtype=int)
        next_slot = jnp.zeros(num_cells, dtype=int)  # where the next point of each cell goes

        def assign_point(i, state):
            grid, next_slot = state
            cell_x, cell_y = self._point_to_cell(points[i], n_grid, extent)
            grid_idx = self._grid_to_idx(cell_x, cell_y, n_grid)
            slot = next_slot[grid_idx]
            grid = grid.at[grid_idx, slot].set(i)
            # A full cell wraps around and overwrites its earliest entries; this is one of
            # the approximations of the structure when a cell exceeds its capacity.
            next_slot = next_slot.at[grid_idx].set((slot + 1) % max_points_per_cell)
            return grid, next_slot

        grid, _ = jax.lax.fori_loop(0, n_points, assign_point, (grid, next_slot))

        # Some cells may not have any points, thus test points that fall within such a cell
        # would have no neighbours. Fill the empty (-1) slots with random points from
        # neighbouring cells; this also softens cell-edge effects.
        num_slots = max_points_per_cell
        max_fill_passes = 10  # some slots may never fill, e.g. too few distinct points nearby
        cell_ids = jnp.arange(num_cells)[:, None]  # [num_cells, 1]
        cell_x, cell_y = self._idx_to_grid(cell_ids, n_grid)  # [num_cells, 1] each
        base_key = jax.random.PRNGKey(42)

        def body(state):
            it, grid = state
            # Fresh randomness every pass: with a fixed key each pass would redraw exactly the
            # same proposals. Separate keys keep the two draws independent.
            inc_key, select_key = jax.random.split(jax.random.fold_in(base_key, it))
            # For every slot, pick a cell from the 3x3 block centred on its own cell
            # (clipped at the grid border).
            neighbour_inc = jax.random.randint(inc_key, (num_cells, num_slots, 2), -1, 2)
            neighbour_x = jnp.clip(cell_x + neighbour_inc[..., 0], 0, n_grid - 1)
            neighbour_y = jnp.clip(cell_y + neighbour_inc[..., 1], 0, n_grid - 1)
            neighbour_idx = self._grid_to_idx(neighbour_x, neighbour_y, n_grid)  # [num_cells, num_slots]
            # Rows are packed (valid indices first), so a slot drawn below the *neighbour's*
            # fill count is a real point whenever that neighbour has one.
            fill_counts = jnp.sum(grid != -1, axis=1)
            random_slot = jax.random.randint(select_key, (num_cells, num_slots), 0,
                                             fill_counts[neighbour_idx])
            candidate = grid[neighbour_idx, random_slot]  # [num_cells, num_slots]
            # Do not insert a point that the receiving cell already holds.
            already_present = jnp.any(grid[:, None, :] == candidate[:, :, None], axis=-1)
            grid = jnp.where((grid == -1) & ~already_present, candidate, grid)
            grid = jnp.sort(grid, axis=1, descending=True)
            # Several empty slots of one cell may have drawn the same point in this pass; after
            # sorting, such copies are adjacent, so keep only the first one.
            repeats_previous = jnp.concatenate(
                [jnp.zeros((num_cells, 1), dtype=bool), grid[:, 1:] == grid[:, :-1]], axis=1
            )
            grid = jnp.where(repeats_previous & (grid != -1), -1, grid)
            grid = jnp.sort(grid, axis=1, descending=True)
            return it + 1, grid

        def cond(state):
            # Until all -1 are replaced, or the pass budget is spent.
            it, grid = state
            return jnp.any(grid == -1) & (it < max_fill_passes)

        _, grid = jax.lax.while_loop(cond, body, (jnp.int32(0), grid))
        return GridTree(grid=grid, points=points, extent=extent)

    def query(self, tree: GridTree, test_point: jax.Array, k: int = 1) -> Tuple[jax.Array, jax.Array]:
        """Find (approximately) the k nearest neighbors of `test_point`.

        Only the points stored in the grid cell containing `test_point` are searched.

        Args:
            tree: The tree structure built by the `build_tree` method.
            test_point: A point with shape (2,); points outside the tree's extent are
                assigned to the nearest border cell.
            k: The number of nearest neighbors to find; must not exceed the cell capacity.

        Returns:
            distances: [k] distances to the nearest neighbors, in increasing order. Results
                beyond the number of points stored in the cell have distance `inf`.
            indices: [k] indices into `tree.points`; -1 for those missing results.
        """
        num_cells = np.shape(tree.grid)[0]
        n_grid = int(round(np.sqrt(num_cells)))  # the grid is square
        cell_x, cell_y = self._point_to_cell(test_point, n_grid, tree.extent)
        grid_idx = self._grid_to_idx(cell_x, cell_y, n_grid)
        point_indices = tree.grid[grid_idx]  # [max_points_per_cell]
        # Empty slots hold -1, which gathers the last point; they are masked out below.
        points_in_cell = tree.points[point_indices]  # [max_points_per_cell, 2]
        valid_mask = point_indices >= 0
        distances = jnp.linalg.norm(points_in_cell - test_point, axis=1)  # [max_points_per_cell]
        # top_k selects the largest values: negate distances and push empty slots to -inf.
        neg_distances = jnp.where(valid_mask, -distances, -jnp.inf)
        top_k_neg_distances, top_k_indices_within_cell = jax.lax.top_k(neg_distances, k)
        return -top_k_neg_distances, point_indices[top_k_indices_within_cell]


# Here's a performance test against brute-force.
def brute_force_nearest_neighbors(points: jnp.ndarray, test_point: jnp.ndarray, k: int) -> Tuple[
        jnp.ndarray, jnp.ndarray]:
    """A brute-force approach to find the k nearest neighbors to a test point.

    Args:
        points: Array of points with shape (n_points, 2).
        test_point: A point with shape (2,).
        k: The number of nearest neighbors to find.

    Returns:
        distances: Distances to the k nearest neighbors, in increasing order.
        indices: Indices of the k nearest neighbors.
    """
    distances = jnp.linalg.norm(points - test_point, axis=1)
    # top_k selects the largest values, so negate to obtain the smallest distances.
    top_k_neg_distances, top_k_indices = jax.lax.top_k(-distances, k)
    return -top_k_neg_distances, top_k_indices
@pytest.mark.parametrize("k", [1, 2, 3])
@pytest.mark.parametrize("m", [1, 100, 1000])
@pytest.mark.parametrize("n", [1000, 100000])
@pytest.mark.parametrize("average_points_per_cell", [9, 16, 25])
@pytest.mark.parametrize("kappa", [1., 5., 10.])
def test_performance(n: int, k: int, m: int, average_points_per_cell: int, kappa: float):
    """Performance test comparing ApproximateTree with brute-force approach for varying number of points.

    Every function is compiled before it is timed, and each timed call is wrapped in
    `jax.block_until_ready`: JAX dispatches work asynchronously, so without it the timers
    would only measure dispatch rather than the computation.
    """
    import time

    # Generate uniformly distributed tree points and query points.
    points = jax.random.uniform(jax.random.PRNGKey(0), (n, 2))
    test_points = jax.random.uniform(jax.random.PRNGKey(1), (m, 2))

    # ApproximateTree method
    approx_tree = ApproximateTreeNN(average_points_per_cell=average_points_per_cell, kappa=kappa)
    build_tree = jax.jit(approx_tree.build_tree).lower(points).compile()
    t0 = time.perf_counter()
    tree = jax.block_until_ready(build_tree(points))
    tree_build_time = time.perf_counter() - t0

    query = jax.jit(jax.vmap(lambda test_point: approx_tree.query(tree, test_point, k))).lower(test_points).compile()
    t0 = time.perf_counter()
    distances_approx, indices_approx = jax.block_until_ready(query(test_points))
    approx_time = time.perf_counter() - t0

    # Brute-force method, compiled ahead of time like the tree query so compilation is
    # not counted against it.
    brute_force_nearest_neighbors_vmap = jax.jit(
        jax.vmap(lambda test_point: brute_force_nearest_neighbors(points, test_point, k))
    ).lower(test_points).compile()
    t0 = time.perf_counter()
    distances_brute, indices_brute = jax.block_until_ready(brute_force_nearest_neighbors_vmap(test_points))
    brute_time = time.perf_counter() - t0

    speedup = brute_time / approx_time
    index_error_rate = np.mean(indices_brute != indices_approx)
    mean_abs_error = np.mean(np.abs(distances_brute - distances_approx))
    rmse = np.sqrt(np.mean((distances_brute - distances_approx) ** 2))
    print(
        f"n={n}, m={m}, k={k}, avg_points_per_cell={average_points_per_cell}, kappa={kappa}:\n"
        f"\tTree build time: {tree_build_time:.5f}s\n"
        f"\tApproximateTree time: {approx_time:.5f}s\n"
        f"\tBrute-force time: {brute_time:.5f}s\n"
        f"\tSpeedup: {speedup:.2f}\n"
        f"\tIndex error rate: {index_error_rate:.2f}\n"
        f"\tMean absolute error: {mean_abs_error:.5f}\n"
        f"\tRMSE: {rmse:.5f}\n"
    )
I have been trying to wrap scipy's KDTree with pure_callback, but I suspect that jax passes an abstract array through the function, because I get the following error: File "/home/danj/.pyenv/versions/3.12.7/envs/jax/lib/python3.12/site-packages/jax/_src/interpreters/", line 2793, in _wrapped_callback
File "/home/danj/.pyenv/versions/3.12.7/envs/jax/lib/python3.12/site-packages/jax/_src/", line 228, in _callback
File "/home/danj/.pyenv/versions/3.12.7/envs/jax/lib/python3.12/site-packages/jax/_src/", line 89, in pure_callback_impl
File "/home/danj/.pyenv/versions/3.12.7/envs/jax/lib/python3.12/site-packages/jax/_src/", line 64, in __call__
File "<ipython-input-41-5e037bac060e>", line 1, in <lambda>
File "/home/danj/.pyenv/versions/3.12.7/envs/jax/lib/python3.12/site-packages/scipy/spatial/", line 475, in query
File "/home/danj/.pyenv/versions/3.12.7/envs/jax/lib/python3.12/site-packages/jax/_src/", line 286, in __len__
# TypeError: len() of unsized object
# Here is the function now:
def scipy_k_nearest_senders(r: jax.Array, s: jax.Array, k: int):
    """Find the `k` nearest senders `s` of every receiver `r` with scipy's KDTree.

    Args:
        r: [m, d] receiver positions (query points).
        s: [n, d] sender positions (the points the tree is built on).
        k: Number of neighbours per receiver; must be a static Python int under `jit`.

    Returns:
        (indices, distances), each flattened to shape [m * k].
    """
    from scipy.spatial import KDTree

    k = int(k)
    d_shape = jax.ShapeDtypeStruct((r.shape[0], k), jnp.float32)
    idx_shape = jax.ShapeDtypeStruct((r.shape[0], k), jnp.int32)

    def host_query(r_host, s_host):
        # `k` is captured here instead of being passed as a callback operand: operands
        # arrive as arrays, and KDTree.query rejects a 0-d array `k`
        # ("len() of unsized object").
        # k=[1] keeps a trailing axis of size 1, which scipy squeezes for an integer k == 1.
        d, idx = KDTree(np.asarray(s_host)).query(np.asarray(r_host), k=[1] if k == 1 else k)
        # The callback results must match the declared dtypes.
        return d.astype(np.float32), idx.astype(np.int32)

    d, idx = jax.pure_callback(host_query, (d_shape, idx_shape), r, s)
    return idx.flatten(), d.flatten()
def kd_tree_nn(points: jax.Array, test_points: jax.Array, k: int = 1) -> Tuple[jax.Array, jax.Array]:
    """Use a KD-tree (scipy, run on the host) to find the k nearest neighbors of test points.

    The points may have any dimension d. The function can be used inside `jax.jit` as
    long as `k` is static.

    Args:
        points: [n, d] Array of points.
        test_points: [m, d] points to query.
        k: The number of nearest neighbors to find (a static Python int).

    Returns:
        distances: [m, k] Distances to the k nearest neighbors.
        indices: [m, k] Indices of the k nearest neighbors.
    """
    m, _ = np.shape(test_points)
    k = int(k)
    float_type, int_type = _kd_tree_nn_dtypes()
    result_shape_dtypes = (
        jax.ShapeDtypeStruct(shape=(m, k), dtype=float_type),
        jax.ShapeDtypeStruct(shape=(m, k), dtype=int_type),
    )
    # `k` is bound in Python instead of being passed as a callback operand: operands reach
    # the host as arrays, and KDTree.query does not accept a 0-d array for `k`.
    return jax.pure_callback(
        lambda p, q: _kd_tree_nn_host(p, q, k), result_shape_dtypes, points, test_points
    )


def _kd_tree_nn_dtypes():
    """Return JAX's default (float, int) dtypes: 64-bit only when x64 mode is enabled."""
    return jax.dtypes.canonicalize_dtype(np.float64), jax.dtypes.canonicalize_dtype(np.int64)


def _kd_tree_nn_host(points: jax.Array, test_points: jax.Array, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Host-side KD-tree query used by `kd_tree_nn` through `jax.pure_callback`.

    Args:
        points: [n, d] Array of points.
        test_points: [m, d] points to query.
        k: The number of nearest neighbors to find.

    Returns:
        distances: [m, k] Distances to the k nearest neighbors (inf if fewer than k exist).
        indices: [m, k] Indices of the k nearest neighbors (n if fewer than k exist).
    """
    from scipy.spatial import KDTree

    points, test_points = np.asarray(points), np.asarray(test_points)
    k = int(k)
    float_type, int_type = _kd_tree_nn_dtypes()
    tree = KDTree(points, compact_nodes=False, balanced_tree=False)
    # For an integer k == 1 scipy squeezes the last axis; k=[1] keeps the [m, 1] shape.
    distances, indices = tree.query(test_points, k=[1] if k == 1 else k)
    # pure_callback results must have exactly the dtypes declared by kd_tree_nn.
    return distances.astype(float_type), indices.astype(int_type)
I have posted a question on Stack Overflow in which the code finds the nearest spheres to each sphere. I have tried to implement the code in JAX, but I couldn't. Two methods for finding nearby points are provided there: SciPy's cKDTree and plain NumPy. However, I couldn't write performant JAX code or an algorithm based on either method. As jakevdp answered on that question, cKDTree is not compatible with JAX, so we must find or write an alternative algorithm. Perhaps this could be achieved with the NumPy-based section of that code (instead of SciPy's cKDTree), or by modifying it.
I would be grateful for pointers to existing nearest-neighbor algorithms written in JAX, or for help with my SO question on implementing this problem in JAX as performantly as possible.
