diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index 8d1618addf..2283c40b8a 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -13,6 +13,7 @@ import ase.io.trajectory import numpy as np import torch +from ase.geometry import wrap_positions from torch_geometric.data import Data from fairchem.core.common.utils import collate @@ -163,10 +164,16 @@ def convert(self, atoms: ase.Atoms, sid=None): """ # set the atomic numbers, positions, and cell + positions = np.array(atoms.get_positions(), copy=True) + pbc = np.array(atoms.pbc, copy=True) + cell = np.array(atoms.get_cell(complete=True), copy=True) + positions = wrap_positions(positions, cell, pbc=pbc, eps=0) + atomic_numbers = torch.Tensor(atoms.get_atomic_numbers()) - positions = torch.Tensor(atoms.get_positions()) - cell = torch.Tensor(np.array(atoms.get_cell())).view(1, 3, 3) + positions = torch.from_numpy(positions).float() + cell = torch.from_numpy(cell).view(1, 3, 3).float() natoms = positions.shape[0] + # initialized to torch.zeros(natoms) if tags missing. # https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags tags = torch.Tensor(atoms.get_tags()) @@ -187,13 +194,16 @@ def convert(self, atoms: ase.Atoms, sid=None): # optionally include other properties if self.r_edges: # run internal functions to get padded indices and distances - split_idx_dist = self._get_neighbors_pymatgen(atoms) + atoms_copy = atoms.copy() + atoms_copy.set_positions(positions) + split_idx_dist = self._get_neighbors_pymatgen(atoms_copy) edge_index, edge_distances, cell_offsets = self._reshape_features( *split_idx_dist ) data.edge_index = edge_index data.cell_offsets = cell_offsets + del atoms_copy if self.r_energy: energy = atoms.get_potential_energy(apply_constraint=False) data.energy = energy diff --git a/tests/core/preprocessing/test_atoms_to_graphs.py b/tests/core/preprocessing/test_atoms_to_graphs.py index ec1c34ab20..5c07a45243 100644 --- a/tests/core/preprocessing/test_atoms_to_graphs.py +++ b/tests/core/preprocessing/test_atoms_to_graphs.py @@ -15,7 +15,7 @@ from ase.neighborlist import NeighborList, NewPrimitiveNeighborList from fairchem.core.preprocessing import AtomsToGraphs - +from fairchem.core.modules.evaluator import min_diff @pytest.fixture(scope="class") def atoms_to_graphs_internals(request) -> None: @@ -110,7 +110,8 @@ def test_convert(self) -> None: # positions act_positions = self.atoms.get_positions() positions = data.pos.numpy() - np.testing.assert_allclose(act_positions, positions) + mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc) + np.testing.assert_allclose(mindiff, 0, atol=1e-6) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) test_energy = data.energy @@ -142,7 +143,8 @@ def test_convert_all(self) -> None: # positions act_positions = self.atoms.get_positions() positions = data_list[0].pos.numpy() - np.testing.assert_allclose(act_positions, positions) + mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc) + np.testing.assert_allclose(mindiff, 0, atol=1e-6) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) test_energy = data_list[0].energy