Skip to content

Commit

Permalink
Full CPU implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Raimondas Galvelis committed Jun 2, 2022
1 parent 7ccd5b6 commit 5144cce
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions src/neighbors/getNeighborPairsCPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

using std::tuple;
using torch::div;
using torch::full;
using torch::index_select;
using torch::indexing::Slice;
using torch::arange;
using torch::frobenius_norm;
using torch::kInt32;
using torch::Scalar;
using torch::stack;
using torch::hstack;
using torch::vstack;
using torch::Tensor;

static tuple<Tensor, Tensor> forward(const Tensor& positions,
Expand All @@ -20,6 +23,12 @@ static tuple<Tensor, Tensor> forward(const Tensor& positions,
TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3");
TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous");

TORCH_CHECK(cutoff.to<double>() > 0, "Expected \"cutoff\" to be positive");

const int max_num_neighbors_ = max_num_neighbors.to<int>();
TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1,
"Expected \"max_num_neighbors\" to be positive or equal to -1");

const int num_atoms = positions.size(0);
const int num_pairs = num_atoms * (num_atoms - 1) / 2;

Expand All @@ -28,9 +37,29 @@ static tuple<Tensor, Tensor> forward(const Tensor& positions,
rows -= (rows * (rows - 1) > 2 * indices).to(kInt32);
const Tensor columns = indices - div(rows * (rows - 1), 2, "floor");

const Tensor neighbors = stack({rows, columns});
Tensor neighbors = vstack({rows, columns});
const Tensor vectors = index_select(positions, 0, rows) - index_select(positions, 0, columns);
const Tensor distances = frobenius_norm(vectors, 1);
Tensor distances = frobenius_norm(vectors, 1);

if (max_num_neighbors_ == -1) {
const Tensor mask = distances > cutoff;
neighbors.index_put_({Slice(), mask}, -1);
distances.index_put_({mask}, NAN);

} else {
const Tensor mask = distances <= cutoff;
neighbors = neighbors.index({Slice(), mask});
distances = distances.index({mask});

const int num_pad = num_atoms * max_num_neighbors_ - distances.size(0);
TORCH_CHECK(num_pad >= 0,
"The maximum number of pairs has been exceed! Increase \"max_num_neighbors\"");

if (num_pad > 0) {
neighbors = hstack({neighbors, full({2, num_pad}, -1, neighbors.options())});
distances = hstack({distances, full({num_pad}, NAN, distances.options())});
}
}

return {neighbors, distances};
}
Expand Down

0 comments on commit 5144cce

Please sign in to comment.