Skip to content

Commit

Permalink
Fix incorrect decorator in NNPOps (#93)
Browse files Browse the repository at this point in the history
* Fix incorrect decorator in NNPOps

* Added a test to check if getNeighborPairs is jit.script compatible
  • Loading branch information
RaulPPelaez authored Mar 20, 2023
1 parent c8690ac commit 4b911e5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
21 changes: 20 additions & 1 deletion src/pytorch/neighbors/TestNeighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import torch as pt
from NNPOps.neighbors import getNeighborPairs

import tempfile

def sort_neighbors(neighbors, deltas, distances):
i_sorted = np.lexsort(neighbors)[::-1]
Expand Down Expand Up @@ -218,3 +218,22 @@ def test_periodic_neighbors(device, dtype):
assert np.all(ref_neighbors == neighbors)
assert np.allclose(ref_deltas, deltas, equal_nan=True)
assert np.allclose(ref_distances, distances, equal_nan=True)


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_jit_script_compatible(device, dtype):

class ForceModule(pt.nn.Module):

def forward(self, positions):

neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0)
mask = pt.isnan(distances)
distances = distances[~mask]
return pt.sum(distances**2)

with tempfile.NamedTemporaryFile() as temp:
model = ForceModule()
module = pt.jit.script(model)
module.save(temp.name)
2 changes: 1 addition & 1 deletion src/pytorch/neighbors/getNeighborPairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, Tuple


def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
'''
Returns indices and distances of atom pairs within a given cutoff distance.
Expand Down

0 comments on commit 4b911e5

Please sign in to comment.