-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Nearest neighbor operation #58
Merged
Merged
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
6f9b785
Move code from TorchMD-NET
be7a305
Enable building
27f37ed
Fix import
10a5c0a
Enable tests
ffd6c43
Harmonize naming
62db847
Fix naming
9838194
Refatoring for documentation
225fc59
Merge branch 'master' into neighbors
4af6699
Select the algorithm based on the num_max_neighbors
bd67dfe
Support pre-Kepler GPUs
65e9a22
Full CPU implementation
eb1c6ad
Improve tests
5bad2af
Test all the cases
8204bf7
Speed up the tests
1d90089
Guard from an overflow
aea4f01
Add documentation
640bc73
Restrucutre the files
8db4677
Add the missing files
b88abd0
Add a removal note
d7d1ce8
Improve documentation
3df1b0c
Add a few examples
e42a10f
Fix the path of the test
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
''' | ||
High-performance PyTorch operations for neural network potentials | ||
''' | ||
|
||
from NNPOps.OptimizedTorchANI import OptimizedTorchANI |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import numpy as np | ||
import pytest | ||
import torch as pt | ||
from NNPOps.neighbors import getNeighborPairs | ||
|
||
|
||
def sort_neighbors(neighbors, distances): | ||
i_sorted = np.lexsort(neighbors)[::-1] | ||
return neighbors[:, i_sorted], distances[i_sorted] | ||
|
||
def resize_neighbors(neighbors, distances, num_neighbors): | ||
|
||
new_neighbors = np.full((2, num_neighbors), -1, dtype=neighbors.dtype) | ||
new_distances = np.full((num_neighbors), np.nan, dtype=distances.dtype) | ||
|
||
if num_neighbors < neighbors.shape[1]: | ||
assert np.all(neighbors[:, num_neighbors:] == -1) | ||
assert np.all(np.isnan(distances[num_neighbors:])) | ||
new_neighbors = neighbors[:, :num_neighbors] | ||
new_distances = distances[:num_neighbors] | ||
else: | ||
num_neighbors = neighbors.shape[1] | ||
new_neighbors[:, :num_neighbors] = neighbors | ||
new_distances[:num_neighbors] = distances | ||
|
||
return new_neighbors, new_distances | ||
|
||
@pytest.mark.parametrize('device', ['cpu', 'cuda']) | ||
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) | ||
@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100, 1000]) | ||
@pytest.mark.parametrize('cutoff', [1, 10, 100]) | ||
@pytest.mark.parametrize('all_pairs', [True, False]) | ||
def test_neighbors(device, dtype, num_atoms, cutoff, all_pairs): | ||
|
||
if not pt.cuda.is_available() and device == 'cuda': | ||
pytest.skip('No GPU') | ||
|
||
# Generate random positions | ||
positions = 10 * pt.randn((num_atoms, 3), device=device, dtype=dtype) | ||
|
||
# Get neighbor pairs | ||
ref_neighbors = np.vstack(np.tril_indices(num_atoms, -1)) | ||
ref_positions = positions.cpu().numpy() | ||
ref_distances = np.linalg.norm(ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]], axis=1) | ||
|
||
# Filter the neighbor pairs | ||
mask = ref_distances > cutoff | ||
ref_neighbors[:, mask] = -1 | ||
ref_distances[mask] = np.nan | ||
|
||
# Find the number of neighbors | ||
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances))) | ||
max_num_neighbors = -1 if all_pairs else max(int(np.ceil(num_neighbors / num_atoms)), 1) | ||
|
||
# Compute results | ||
neighbors, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors) | ||
|
||
# Check device | ||
assert neighbors.device == positions.device | ||
assert distances.device == positions.device | ||
|
||
# Check types | ||
assert neighbors.dtype == pt.int32 | ||
assert distances.dtype == dtype | ||
|
||
# Covert the results | ||
neighbors = neighbors.cpu().numpy() | ||
distances = distances.cpu().numpy() | ||
|
||
if not all_pairs: | ||
# Sort the neighbors | ||
# NOTE: GPU returns the neighbor in a non-deterministic order | ||
ref_neighbors, ref_distances = sort_neighbors(ref_neighbors, ref_distances) | ||
neighbors, distances = sort_neighbors(neighbors, distances) | ||
|
||
# Resize the reference | ||
ref_neighbors, ref_distances = resize_neighbors(ref_neighbors, ref_distances, num_atoms * max_num_neighbors) | ||
|
||
assert np.all(ref_neighbors == neighbors) | ||
assert np.allclose(ref_distances, distances, equal_nan=True) | ||
|
||
@pytest.mark.parametrize('device', ['cpu', 'cuda']) | ||
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) | ||
def test_too_many_neighbors(device, dtype): | ||
|
||
if not pt.cuda.is_available() and device == 'cuda': | ||
pytest.skip('No GPU') | ||
|
||
# 4 points result into 6 pairs, but there is a storage just for 4. | ||
with pytest.raises(RuntimeError): | ||
positions = pt.zeros((4, 3,), device=device, dtype=dtype) | ||
getNeighborPairs(positions, cutoff=1, max_num_neighbors=1) | ||
pt.cuda.synchronize() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
''' | ||
Neighbor operations | ||
''' | ||
|
||
from NNPOps.neighbors.getNeighborPairs import getNeighborPairs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from torch import ops, Tensor | ||
from typing import Tuple | ||
|
||
|
||
def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1) -> Tuple[Tensor, Tensor]: | ||
''' | ||
Returns indices and distances of atom pairs withing a given cutoff distance. | ||
|
||
If `max_num_neighbors == -1` (default), all the atom pairs are returned, | ||
(i.e. `num_atoms * (num_atoms + 1) / 2` pairs). If `max_num_neighbors > 0`, | ||
only `num_atoms * max_num_neighbors` pairs are returned. | ||
|
||
Parameters | ||
---------- | ||
positions: `torch.Tensor` | ||
Atomic positions. The tensor shape has to be `(num_atoms, 3)` and | ||
data type has to be`torch.float32` or `torch.float64`. | ||
cutoff: float | ||
Maximum distance between atom pairs. | ||
max_num_neighbors: int, optional | ||
Maximum number of neighbors per atom. If set to `-1` (default), | ||
all possible combinations of atom pairs are included. | ||
|
||
Returns | ||
------- | ||
neighbors: `torch.Tensor` | ||
Atom pair indices. The shape of the tensor is `(2, num_pairs)`. | ||
If an atom pair is separated by a larger distance than the cutoff, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Defined |
||
the indices are set to `-1`. | ||
|
||
distances: `torch.Tensor` | ||
Atom pair distances. The shape of the tensor is `(num_pairs)`. | ||
If an atom pair is separated by a larger distance than the cutoff, | ||
the distance is set to `NaN`. | ||
|
||
Exceptions | ||
---------- | ||
If `max_num_neighbors > 0` and too small, `RuntimeError` is raised. | ||
|
||
Note | ||
---- | ||
The CUDA implementation returns the atom pairs in non-determinist order, | ||
if `max_num_neighbors > 0`. | ||
''' | ||
|
||
return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#include <torch/extension.h> | ||
#include <tuple> | ||
|
||
using std::tuple; | ||
using torch::div; | ||
using torch::full; | ||
using torch::index_select; | ||
using torch::indexing::Slice; | ||
using torch::arange; | ||
using torch::frobenius_norm; | ||
using torch::kInt32; | ||
using torch::Scalar; | ||
using torch::hstack; | ||
using torch::vstack; | ||
using torch::Tensor; | ||
|
||
static tuple<Tensor, Tensor> forward(const Tensor& positions, | ||
const Scalar& cutoff, | ||
const Scalar& max_num_neighbors) { | ||
|
||
TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); | ||
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); | ||
TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); | ||
TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); | ||
|
||
TORCH_CHECK(cutoff.to<double>() > 0, "Expected \"cutoff\" to be positive"); | ||
|
||
const int max_num_neighbors_ = max_num_neighbors.to<int>(); | ||
TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1, | ||
"Expected \"max_num_neighbors\" to be positive or equal to -1"); | ||
|
||
const int num_atoms = positions.size(0); | ||
const int num_pairs = num_atoms * (num_atoms - 1) / 2; | ||
|
||
const Tensor indices = arange(0, num_pairs, positions.options().dtype(kInt32)); | ||
Tensor rows = (((8 * indices + 1).sqrt() + 1) / 2).floor().to(kInt32); | ||
rows -= (rows * (rows - 1) > 2 * indices).to(kInt32); | ||
const Tensor columns = indices - div(rows * (rows - 1), 2, "floor"); | ||
|
||
Tensor neighbors = vstack({rows, columns}); | ||
const Tensor vectors = index_select(positions, 0, rows) - index_select(positions, 0, columns); | ||
Tensor distances = frobenius_norm(vectors, 1); | ||
|
||
if (max_num_neighbors_ == -1) { | ||
const Tensor mask = distances > cutoff; | ||
neighbors.index_put_({Slice(), mask}, -1); | ||
distances.index_put_({mask}, NAN); | ||
|
||
} else { | ||
const Tensor mask = distances <= cutoff; | ||
neighbors = neighbors.index({Slice(), mask}); | ||
distances = distances.index({mask}); | ||
|
||
const int num_pad = num_atoms * max_num_neighbors_ - distances.size(0); | ||
TORCH_CHECK(num_pad >= 0, | ||
"The maximum number of pairs has been exceed! Increase \"max_num_neighbors\""); | ||
|
||
if (num_pad > 0) { | ||
neighbors = hstack({neighbors, full({2, num_pad}, -1, neighbors.options())}); | ||
distances = hstack({distances, full({num_pad}, NAN, distances.options())}); | ||
} | ||
} | ||
|
||
return {neighbors, distances}; | ||
} | ||
|
||
TORCH_LIBRARY_IMPL(neighbors, CPU, m) { | ||
m.impl("getNeighborPairs", &forward); | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit ambiguous. Is it saying that if
max_num_neighbors
is -1 then the cutoff is ignored and all pairs are returned regardless of distance? Or is it simply returning the result in a different shape tensor?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I improved the text and added a few example.