-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Metatensor interface for vesin (#23)
Co-authored-by: Guillaume Fraux <guillaume.fraux@epfl.ch>
- Loading branch information
Showing
10 changed files
with
446 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -405,6 +405,7 @@ API Reference | |
python-api | ||
torch-api | ||
c-api | ||
metatensor | ||
|
||
|
||
.. toctree:: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
.. _metatensor-api: | ||
|
||
Metatensor interface | ||
==================== | ||
|
||
.. autofunction:: vesin.torch.metatensor.compute_requested_neighbors | ||
|
||
|
||
.. autoclass:: vesin.torch.metatensor.NeighborList | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from typing import Dict, List, Optional | ||
|
||
import pytest | ||
import torch | ||
from metatensor.torch import Labels, TensorMap | ||
from metatensor.torch.atomistic import ( | ||
MetatensorAtomisticModel, | ||
ModelCapabilities, | ||
ModelMetadata, | ||
ModelOutput, | ||
NeighborListOptions, | ||
System, | ||
) | ||
|
||
from vesin.torch.metatensor import NeighborList, compute_requested_neighbors | ||
|
||
|
||
def test_errors(): | ||
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64) | ||
cell = 4 * torch.eye(3, dtype=torch.float64) | ||
system = System( | ||
positions=positions, | ||
cell=cell, | ||
pbc=torch.ones(3, dtype=bool), | ||
types=torch.tensor([6, 8]), | ||
) | ||
|
||
options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) | ||
calculator = NeighborList(options, length_unit="A") | ||
|
||
system.pbc[0] = False | ||
message = ( | ||
"vesin currently does not support mixed periodic and non-periodic " | ||
"boundary conditions" | ||
) | ||
with pytest.raises(NotImplementedError, match=message): | ||
calculator.compute(system) | ||
|
||
|
||
def test_script(): | ||
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64) | ||
cell = 4 * torch.eye(3, dtype=torch.float64) | ||
system = System( | ||
positions=positions, | ||
cell=cell, | ||
pbc=torch.ones(3, dtype=bool), | ||
types=torch.tensor([6, 8]), | ||
) | ||
|
||
options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) | ||
calculator = torch.jit.script(NeighborList(options, length_unit="A")) | ||
calculator.compute(system) | ||
|
||
|
||
def test_backward(): | ||
positions = torch.tensor( | ||
[[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True | ||
) | ||
cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True) | ||
system = System( | ||
positions=positions, | ||
cell=cell, | ||
pbc=torch.ones(3, dtype=bool), | ||
types=torch.tensor([6, 8]), | ||
) | ||
|
||
options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) | ||
calculator = NeighborList(options, length_unit="A") | ||
neighbors = calculator.compute(system) | ||
|
||
value = ((neighbors.values) ** 2).sum() * torch.linalg.det(cell) | ||
value.backward() | ||
|
||
# check there are gradients, and they are not zero | ||
assert positions.grad is not None | ||
assert cell.grad is not None | ||
assert torch.linalg.norm(positions.grad) > 0 | ||
assert torch.linalg.norm(cell.grad) > 0 | ||
|
||
|
||
class InnerModule(torch.nn.Module): | ||
def requested_neighbor_lists(self) -> List[NeighborListOptions]: | ||
return [NeighborListOptions(cutoff=3.4, full_list=False, strict=True)] | ||
|
||
|
||
class OuterModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.inner = InnerModule() | ||
|
||
def requested_neighbor_lists(self) -> List[NeighborListOptions]: | ||
return [NeighborListOptions(cutoff=5.2, full_list=True, strict=False)] | ||
|
||
def forward( | ||
self, | ||
systems: List[System], | ||
outputs: Dict[str, ModelOutput], | ||
selected_atoms: Optional[Labels], | ||
) -> Dict[str, TensorMap]: | ||
return {} | ||
|
||
|
||
def test_model(): | ||
positions = torch.tensor( | ||
[[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True | ||
) | ||
cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True) | ||
pbc = torch.ones(3, dtype=bool) | ||
types = torch.tensor([6, 8]) | ||
systems = [ | ||
System(positions=positions, cell=cell, pbc=pbc, types=types), | ||
System(positions=positions, cell=cell, pbc=pbc, types=types), | ||
] | ||
|
||
# Using a "raw" model | ||
model = OuterModule() | ||
compute_requested_neighbors( | ||
systems=systems, system_length_unit="A", model=model, model_length_unit="A" | ||
) | ||
|
||
for system in systems: | ||
all_options = system.known_neighbor_lists() | ||
assert len(all_options) == 2 | ||
assert all_options[0].requestors() == ["OuterModule"] | ||
assert all_options[0].cutoff == 5.2 | ||
assert all_options[1].requestors() == ["OuterModule.inner"] | ||
assert all_options[1].cutoff == 3.4 | ||
|
||
# Using a MetatensorAtomisticModel model | ||
capabilities = ModelCapabilities( | ||
length_unit="A", | ||
interaction_range=6.0, | ||
supported_devices=["cpu"], | ||
dtype="float64", | ||
) | ||
model = MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities) | ||
compute_requested_neighbors( | ||
systems=System(positions=positions, cell=cell, pbc=pbc, types=types), | ||
system_length_unit="A", | ||
model=model, | ||
) | ||
|
||
for system in systems: | ||
all_options = system.known_neighbor_lists() | ||
assert len(all_options) == 2 | ||
assert all_options[0].requestors() == ["OuterModule"] | ||
assert all_options[0].cutoff == 5.2 | ||
assert all_options[1].requestors() == ["OuterModule.inner"] | ||
assert all_options[1].cutoff == 3.4 | ||
|
||
message = ( | ||
"the given `model_length_unit` \\(nm\\) does not match the model " | ||
"capabilities \\(A\\)" | ||
) | ||
with pytest.raises(ValueError, match=message): | ||
compute_requested_neighbors( | ||
systems=System(positions=positions, cell=cell, pbc=pbc, types=types), | ||
system_length_unit="A", | ||
model=model, | ||
model_length_unit="nm", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,9 @@ | ||
import importlib.metadata | ||
|
||
from ._c_lib import _load_library | ||
from ._neighbors import NeighborList | ||
from ._neighbors import NeighborList # noqa: F401 | ||
|
||
|
||
__version__ = importlib.metadata.version("vesin-torch") | ||
__all__ = ["NeighborList"] | ||
|
||
_load_library() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from ._model import compute_requested_neighbors # noqa: F401 | ||
from ._neighbors import NeighborList # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
|
||
from ._neighbors import NeighborList | ||
|
||
|
||
try: | ||
from metatensor.torch.atomistic import ( | ||
MetatensorAtomisticModel, | ||
ModelInterface, | ||
NeighborListOptions, | ||
System, | ||
) | ||
|
||
_HAS_METATENSOR = True | ||
except ModuleNotFoundError: | ||
_HAS_METATENSOR = False | ||
|
||
class MetatensorAtomisticModel: | ||
pass | ||
|
||
class ModelInterface: | ||
pass | ||
|
||
class NeighborListOptions: | ||
pass | ||
|
||
class System: | ||
pass | ||
|
||
|
||
def compute_requested_neighbors( | ||
systems: Union[List[System], System], | ||
system_length_unit: str, | ||
model: Union[MetatensorAtomisticModel, ModelInterface], | ||
model_length_unit: Optional[str] = None, | ||
): | ||
""" | ||
Compute all neighbors lists requested by the ``model`` through | ||
``requested_neighbor_lists()`` member functions, and store them inside all the | ||
``systems``. | ||
:param systems: Single system or list of systems for which we need to compute the | ||
neighbor lists that the model requires. | ||
:param system_length_unit: unit of length used by the data in ``systems`` | ||
:param model: :py:class:`MetatensorAtomisticModel` or any ``torch.nn.Module`` | ||
following the :py:class:`ModelInterface` | ||
:param model_length_unit: unit of length used by the model, optional. This is only | ||
required when giving a raw model instead of a | ||
:py:class:`MetatensorAtomisticModel`. | ||
""" | ||
|
||
if isinstance(model, MetatensorAtomisticModel): | ||
if model_length_unit is not None: | ||
if model.capabilities().length_unit != model_length_unit: | ||
raise ValueError( | ||
f"the given `model_length_unit` ({model_length_unit}) does not " | ||
f"match the model capabilities ({model.capabilities().length_unit})" | ||
) | ||
|
||
all_options = model.requested_neighbor_lists() | ||
elif isinstance(model, torch.nn.Module): | ||
if model_length_unit is None: | ||
raise ValueError( | ||
"`model_length_unit` parameter is required when not " | ||
"using MetatensorAtomisticModel" | ||
) | ||
|
||
all_options = [] | ||
_get_requested_neighbor_lists( | ||
model, model.__class__.__name__, all_options, model_length_unit | ||
) | ||
|
||
if not isinstance(systems, list): | ||
systems = [systems] | ||
|
||
for options in all_options: | ||
calculator = NeighborList(options, system_length_unit) | ||
for system in systems: | ||
neighbors = calculator.compute(system) | ||
system.add_neighbor_list(options, neighbors) | ||
|
||
|
||
def _get_requested_neighbor_lists( | ||
module: torch.nn.Module, | ||
module_name: str, | ||
requested: List[NeighborListOptions], | ||
length_unit: str, | ||
): | ||
""" | ||
Recursively extract the requested neighbor lists from a non-exported metatensor | ||
atomistic model. | ||
""" | ||
if hasattr(module, "requested_neighbor_lists"): | ||
for new_options in module.requested_neighbor_lists(): | ||
new_options.add_requestor(module_name) | ||
|
||
already_requested = False | ||
for existing in requested: | ||
if existing == new_options: | ||
already_requested = True | ||
for requestor in new_options.requestors(): | ||
existing.add_requestor(requestor) | ||
|
||
if not already_requested: | ||
if new_options.length_unit not in ["", length_unit]: | ||
raise ValueError( | ||
f"NeighborsListOptions from {module_name} already have a " | ||
f"length unit ('{new_options.length_unit}') which does not " | ||
f"match the model length units ('{length_unit}')" | ||
) | ||
|
||
new_options.length_unit = length_unit | ||
requested.append(new_options) | ||
|
||
for child_name, child in module.named_children(): | ||
_get_requested_neighbor_lists( | ||
module=child, | ||
module_name=module_name + "." + child_name, | ||
requested=requested, | ||
length_unit=length_unit, | ||
) |
Oops, something went wrong.