Skip to content

Commit

Permalink
Wrap atom coordinates in ase.Atoms preprocessing (#783)
Browse files Browse the repository at this point in the history
* wrap atom coordinates in ase.Atoms preprocessing.

* fix linting.

* adjust atoms_to_graphs test cases.
  • Loading branch information
kyonofx authored Aug 7, 2024
1 parent b2eebb6 commit bc1307f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
16 changes: 13 additions & 3 deletions src/fairchem/core/preprocessing/atoms_to_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ase.io.trajectory
import numpy as np
import torch
from ase.geometry import wrap_positions
from torch_geometric.data import Data

from fairchem.core.common.utils import collate
Expand Down Expand Up @@ -163,10 +164,16 @@ def convert(self, atoms: ase.Atoms, sid=None):
"""

# set the atomic numbers, positions, and cell
positions = np.array(atoms.get_positions(), copy=True)
pbc = np.array(atoms.pbc, copy=True)
cell = np.array(atoms.get_cell(complete=True), copy=True)
positions = wrap_positions(positions, cell, pbc=pbc, eps=0)

atomic_numbers = torch.Tensor(atoms.get_atomic_numbers())
positions = torch.Tensor(atoms.get_positions())
cell = torch.Tensor(np.array(atoms.get_cell())).view(1, 3, 3)
positions = torch.from_numpy(positions).float()
cell = torch.from_numpy(cell).view(1, 3, 3).float()
natoms = positions.shape[0]

# initialized to torch.zeros(natoms) if tags missing.
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags
tags = torch.Tensor(atoms.get_tags())
Expand All @@ -187,13 +194,16 @@ def convert(self, atoms: ase.Atoms, sid=None):
# optionally include other properties
if self.r_edges:
# run internal functions to get padded indices and distances
split_idx_dist = self._get_neighbors_pymatgen(atoms)
atoms_copy = atoms.copy()
atoms_copy.set_positions(positions)
split_idx_dist = self._get_neighbors_pymatgen(atoms_copy)
edge_index, edge_distances, cell_offsets = self._reshape_features(
*split_idx_dist
)

data.edge_index = edge_index
data.cell_offsets = cell_offsets
del atoms_copy
if self.r_energy:
energy = atoms.get_potential_energy(apply_constraint=False)
data.energy = energy
Expand Down
8 changes: 5 additions & 3 deletions tests/core/preprocessing/test_atoms_to_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ase.neighborlist import NeighborList, NewPrimitiveNeighborList

from fairchem.core.preprocessing import AtomsToGraphs

from fairchem.core.modules.evaluator import min_diff

@pytest.fixture(scope="class")
def atoms_to_graphs_internals(request) -> None:
Expand Down Expand Up @@ -110,7 +110,8 @@ def test_convert(self) -> None:
# positions
act_positions = self.atoms.get_positions()
positions = data.pos.numpy()
np.testing.assert_allclose(act_positions, positions)
mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc)
np.testing.assert_allclose(mindiff, 0, atol=1e-6)
# check energy value
act_energy = self.atoms.get_potential_energy(apply_constraint=False)
test_energy = data.energy
Expand Down Expand Up @@ -142,7 +143,8 @@ def test_convert_all(self) -> None:
# positions
act_positions = self.atoms.get_positions()
positions = data_list[0].pos.numpy()
np.testing.assert_allclose(act_positions, positions)
mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc)
np.testing.assert_allclose(mindiff, 0, atol=1e-6)
# check energy value
act_energy = self.atoms.get_potential_energy(apply_constraint=False)
test_energy = data_list[0].energy
Expand Down

0 comments on commit bc1307f

Please sign in to comment.