Skip to content

Commit

Permalink
Metatensor interface for vesin (#23)
Browse files Browse the repository at this point in the history
Co-authored-by: Guillaume Fraux <guillaume.fraux@epfl.ch>
  • Loading branch information
ceriottm and Luthaf authored Nov 14, 2024
1 parent 491da1f commit 2ec2858
Show file tree
Hide file tree
Showing 10 changed files with 446 additions and 6 deletions.
4 changes: 4 additions & 0 deletions docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import sys
from datetime import datetime


os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"

ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, ROOT)
Expand Down Expand Up @@ -63,6 +66,7 @@ def setup(app):

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"metatensor": ("https://docs.metatensor.org/latest", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"ase": ("https://wiki.fysik.dtu.dk/ase/", None),
Expand Down
1 change: 1 addition & 0 deletions docs/src/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ API Reference
python-api
torch-api
c-api
metatensor


.. toctree::
Expand Down
10 changes: 10 additions & 0 deletions docs/src/metatensor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. _metatensor-api:

Metatensor interface
====================

.. autofunction:: vesin.torch.metatensor.compute_requested_neighbors


.. autoclass:: vesin.torch.metatensor.NeighborList
:members:
161 changes: 161 additions & 0 deletions python/vesin-torch/tests/test_metatensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import Dict, List, Optional

import pytest
import torch
from metatensor.torch import Labels, TensorMap
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelCapabilities,
ModelMetadata,
ModelOutput,
NeighborListOptions,
System,
)

from vesin.torch.metatensor import NeighborList, compute_requested_neighbors


def test_errors():
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64)
cell = 4 * torch.eye(3, dtype=torch.float64)
system = System(
positions=positions,
cell=cell,
pbc=torch.ones(3, dtype=bool),
types=torch.tensor([6, 8]),
)

options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True)
calculator = NeighborList(options, length_unit="A")

system.pbc[0] = False
message = (
"vesin currently does not support mixed periodic and non-periodic "
"boundary conditions"
)
with pytest.raises(NotImplementedError, match=message):
calculator.compute(system)


def test_script():
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64)
cell = 4 * torch.eye(3, dtype=torch.float64)
system = System(
positions=positions,
cell=cell,
pbc=torch.ones(3, dtype=bool),
types=torch.tensor([6, 8]),
)

options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True)
calculator = torch.jit.script(NeighborList(options, length_unit="A"))
calculator.compute(system)


def test_backward():
positions = torch.tensor(
[[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True
)
cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True)
system = System(
positions=positions,
cell=cell,
pbc=torch.ones(3, dtype=bool),
types=torch.tensor([6, 8]),
)

options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True)
calculator = NeighborList(options, length_unit="A")
neighbors = calculator.compute(system)

value = ((neighbors.values) ** 2).sum() * torch.linalg.det(cell)
value.backward()

# check there are gradients, and they are not zero
assert positions.grad is not None
assert cell.grad is not None
assert torch.linalg.norm(positions.grad) > 0
assert torch.linalg.norm(cell.grad) > 0


class InnerModule(torch.nn.Module):
def requested_neighbor_lists(self) -> List[NeighborListOptions]:
return [NeighborListOptions(cutoff=3.4, full_list=False, strict=True)]


class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.inner = InnerModule()

def requested_neighbor_lists(self) -> List[NeighborListOptions]:
return [NeighborListOptions(cutoff=5.2, full_list=True, strict=False)]

def forward(
self,
systems: List[System],
outputs: Dict[str, ModelOutput],
selected_atoms: Optional[Labels],
) -> Dict[str, TensorMap]:
return {}


def test_model():
positions = torch.tensor(
[[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True
)
cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True)
pbc = torch.ones(3, dtype=bool)
types = torch.tensor([6, 8])
systems = [
System(positions=positions, cell=cell, pbc=pbc, types=types),
System(positions=positions, cell=cell, pbc=pbc, types=types),
]

# Using a "raw" model
model = OuterModule()
compute_requested_neighbors(
systems=systems, system_length_unit="A", model=model, model_length_unit="A"
)

for system in systems:
all_options = system.known_neighbor_lists()
assert len(all_options) == 2
assert all_options[0].requestors() == ["OuterModule"]
assert all_options[0].cutoff == 5.2
assert all_options[1].requestors() == ["OuterModule.inner"]
assert all_options[1].cutoff == 3.4

# Using a MetatensorAtomisticModel model
capabilities = ModelCapabilities(
length_unit="A",
interaction_range=6.0,
supported_devices=["cpu"],
dtype="float64",
)
model = MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities)
compute_requested_neighbors(
systems=System(positions=positions, cell=cell, pbc=pbc, types=types),
system_length_unit="A",
model=model,
)

for system in systems:
all_options = system.known_neighbor_lists()
assert len(all_options) == 2
assert all_options[0].requestors() == ["OuterModule"]
assert all_options[0].cutoff == 5.2
assert all_options[1].requestors() == ["OuterModule.inner"]
assert all_options[1].cutoff == 3.4

message = (
"the given `model_length_unit` \\(nm\\) does not match the model "
"capabilities \\(A\\)"
)
with pytest.raises(ValueError, match=message):
compute_requested_neighbors(
systems=System(positions=positions, cell=cell, pbc=pbc, types=types),
system_length_unit="A",
model=model,
model_length_unit="nm",
)
3 changes: 1 addition & 2 deletions python/vesin-torch/vesin/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import importlib.metadata

from ._c_lib import _load_library
from ._neighbors import NeighborList
from ._neighbors import NeighborList # noqa: F401


__version__ = importlib.metadata.version("vesin-torch")
__all__ = ["NeighborList"]

_load_library()
2 changes: 2 additions & 0 deletions python/vesin-torch/vesin/torch/metatensor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._model import compute_requested_neighbors # noqa: F401
from ._neighbors import NeighborList # noqa: F401
123 changes: 123 additions & 0 deletions python/vesin-torch/vesin/torch/metatensor/_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import List, Optional, Union

import torch

from ._neighbors import NeighborList


try:
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelInterface,
NeighborListOptions,
System,
)

_HAS_METATENSOR = True
except ModuleNotFoundError:
_HAS_METATENSOR = False

class MetatensorAtomisticModel:
pass

class ModelInterface:
pass

class NeighborListOptions:
pass

class System:
pass


def compute_requested_neighbors(
systems: Union[List[System], System],
system_length_unit: str,
model: Union[MetatensorAtomisticModel, ModelInterface],
model_length_unit: Optional[str] = None,
):
"""
Compute all neighbors lists requested by the ``model`` through
``requested_neighbor_lists()`` member functions, and store them inside all the
``systems``.
:param systems: Single system or list of systems for which we need to compute the
neighbor lists that the model requires.
:param system_length_unit: unit of length used by the data in ``systems``
:param model: :py:class:`MetatensorAtomisticModel` or any ``torch.nn.Module``
following the :py:class:`ModelInterface`
:param model_length_unit: unit of length used by the model, optional. This is only
required when giving a raw model instead of a
:py:class:`MetatensorAtomisticModel`.
"""

if isinstance(model, MetatensorAtomisticModel):
if model_length_unit is not None:
if model.capabilities().length_unit != model_length_unit:
raise ValueError(
f"the given `model_length_unit` ({model_length_unit}) does not "
f"match the model capabilities ({model.capabilities().length_unit})"
)

all_options = model.requested_neighbor_lists()
elif isinstance(model, torch.nn.Module):
if model_length_unit is None:
raise ValueError(
"`model_length_unit` parameter is required when not "
"using MetatensorAtomisticModel"
)

all_options = []
_get_requested_neighbor_lists(
model, model.__class__.__name__, all_options, model_length_unit
)

if not isinstance(systems, list):
systems = [systems]

for options in all_options:
calculator = NeighborList(options, system_length_unit)
for system in systems:
neighbors = calculator.compute(system)
system.add_neighbor_list(options, neighbors)


def _get_requested_neighbor_lists(
module: torch.nn.Module,
module_name: str,
requested: List[NeighborListOptions],
length_unit: str,
):
"""
Recursively extract the requested neighbor lists from a non-exported metatensor
atomistic model.
"""
if hasattr(module, "requested_neighbor_lists"):
for new_options in module.requested_neighbor_lists():
new_options.add_requestor(module_name)

already_requested = False
for existing in requested:
if existing == new_options:
already_requested = True
for requestor in new_options.requestors():
existing.add_requestor(requestor)

if not already_requested:
if new_options.length_unit not in ["", length_unit]:
raise ValueError(
f"NeighborsListOptions from {module_name} already have a "
f"length unit ('{new_options.length_unit}') which does not "
f"match the model length units ('{length_unit}')"
)

new_options.length_unit = length_unit
requested.append(new_options)

for child_name, child in module.named_children():
_get_requested_neighbor_lists(
module=child,
module_name=module_name + "." + child_name,
requested=requested,
length_unit=length_unit,
)
Loading

0 comments on commit 2ec2858

Please sign in to comment.