Skip to content

Commit

Permalink
improve the _three_body.py and test_M3GNetCalculator in test_ase.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kenko911 committed Aug 29, 2023
1 parent 120a36e commit 9e1a24b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
6 changes: 3 additions & 3 deletions matgl/layers/_three_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def forward(
end_atom_index = torch.gather(graph.edges()[1], 0, line_graph.edges()[1].to(torch.int64))
atoms = self.update_network_atom(node_feat)
end_atom_index = torch.unsqueeze(end_atom_index, 1)
atoms = torch.squeeze(atoms[end_atom_index.long()])
atoms = torch.squeeze(atoms[end_atom_index])
basis = three_basis * atoms
three_cutoff = torch.unsqueeze(three_cutoff, dim=1) # type: ignore
weights = torch.reshape(three_cutoff[torch.stack(list(line_graph.edges()), dim=1)], (-1, 2)) # type: ignore
weights = torch.prod(weights, axis=-1) # type: ignore
weights = three_cutoff[torch.stack(list(line_graph.edges()), dim=1)].view(-1, 2) # type: ignore
weights = torch.prod(weights, dim=-1) # type: ignore
basis = basis * weights[:, None]
new_bonds = scatter_sum(
basis.to(matgl.float_th),
Expand Down
10 changes: 4 additions & 6 deletions matgl/utils/maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,11 @@ def get_segment_indices_from_n(ns):
ns: torch.Tensor, the number of atoms/bonds array
Returns:
object:
Returns: segment indices tensor
torch.Tensor: segment indices tensor
"""
B = ns
A = torch.arange(B.size(dim=0))
return A.repeat_interleave(B, dim=0)
segments = torch.zeros(ns.sum(), dtype=matgl.int_th)
segments[ns.cumsum(0)[:-1]] = 1
return segments.cumsum(0)


def get_range_indices_from_n(ns):
Expand Down
6 changes: 2 additions & 4 deletions tests/ext/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@
from pymatgen.io.ase import AseAtomsAdaptor

from matgl import load_model
from matgl.apps.pes import Potential
from matgl.ext.ase import Atoms2Graph, M3GNetCalculator, MolecularDynamics, Relaxer
from matgl.models import M3GNet


def test_M3GNetCalculator(MoS):
adaptor = AseAtomsAdaptor()
s_ase = adaptor.get_atoms(MoS) # type: ignore
model = M3GNet(element_types=["Mo", "S"], is_intensive=False)
ff = Potential(model=model)
ff = load_model("M3GNet-MP-2021.2.8-PES")
calc = M3GNetCalculator(potential=ff)
s_ase.set_calculator(calc)
assert [s_ase.get_potential_energy().size] == [1]
assert list(s_ase.get_forces().shape) == [2, 3]
assert list(s_ase.get_stress().shape) == [6]
np.testing.assert_allclose(s_ase.get_potential_energy(), -10.312888)


def test_M3GNetCalculator_mol(AcAla3NHMe):
Expand Down

0 comments on commit 9e1a24b

Please sign in to comment.