Skip to content

Commit

Permalink
Make quantities="P" return the two indices from a pair at once
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Nov 25, 2024
1 parent e6a7c4a commit 13570a4
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/vesin-torch/vesin/torch/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def compute(
- ``"i"`` to get the index of the first point in the pair
- ``"j"`` to get the index of the second point in the pair
- ``"p"`` to get the indexes of the two points in the pair simultaneously
- ``"P"`` to get the indexes of the two points in the pair simultaneously
- ``"S"`` to get the periodic shift of the pair
- ``"d"`` to get the distance between points in the pair
- ``"D"`` to get the distance vector between points in the pair
Expand Down
6 changes: 3 additions & 3 deletions python/vesin-torch/vesin/torch/metatensor/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def compute(self, system: System) -> TensorBlock:
)

# computes neighbor list
(i, j, S, D) = self._nl.compute(
points=points, box=box, periodic=periodic, quantities="ijSD", copy=True
(P, S, D) = self._nl.compute(
points=points, box=box, periodic=periodic, quantities="PSD", copy=True
)

# converts to a suitable TensorBlock format
Expand All @@ -130,7 +130,7 @@ def compute(self, system: System) -> TensorBlock:
"cell_shift_b",
"cell_shift_c",
],
values=torch.hstack([i.reshape(-1, 1), j.reshape(-1, 1), S]),
values=torch.hstack([P, S]),
),
components=self._components,
properties=self._properties,
Expand Down
11 changes: 11 additions & 0 deletions python/vesin/tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def test_neighbors(system, cutoff, vesin_nl):
assert np.allclose(ase_D[ase_sort_indices], vesin_D[vesin_sort_indices])


def test_pairs_output():
atoms = ase.io.read(f"{CURRENT_DIR}/data/diamond.xyz")

calculator = vesin.NeighborList(cutoff=2.0, full_list=True, sorted=False)
i, j, P = calculator.compute(
points=atoms.positions, box=atoms.cell[:], periodic=True, quantities="ijP"
)

assert np.all(np.vstack([i, j]).T == P)


def test_sorting():
atoms = ase.io.read(f"{CURRENT_DIR}/data/diamond.xyz")

Expand Down
4 changes: 2 additions & 2 deletions python/vesin/vesin/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def compute(
- ``"i"`` to get the index of the first point in the pair
- ``"j"`` to get the index of the second point in the pair
- ``"p"`` to get the indexes of the two points in the pair simultaneously
- ``"P"`` to get the indexes of the two points in the pair simultaneously
- ``"S"`` to get the periodic shift of the pair
- ``"d"`` to get the distance between points in the pair
- ``"D"`` to get the distance vector between points in the pair
Expand Down Expand Up @@ -133,7 +133,7 @@ def compute(

data = []
for quantity in quantities:
if quantity == "p":
if quantity == "P":
if copy:
data.append(pairs.copy())
else:
Expand Down
2 changes: 2 additions & 0 deletions vesin/torch/src/vesin_torch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ std::vector<torch::Tensor> NeighborListHolder::compute(
output.push_back(pairs.index({torch::indexing::Slice(), 0}));
} else if (c == 'j') {
output.push_back(pairs.index({torch::indexing::Slice(), 1}));
} else if (c == 'P') {
output.push_back(pairs);
} else if (c == 'S') {
output.push_back(shifts);
} else if (c == 'd') {
Expand Down

0 comments on commit 13570a4

Please sign in to comment.