diff --git a/matgl/layers/_three_body.py b/matgl/layers/_three_body.py index 4f155ce4..2afec90e 100644 --- a/matgl/layers/_three_body.py +++ b/matgl/layers/_three_body.py @@ -53,11 +53,11 @@ def forward( end_atom_index = torch.gather(graph.edges()[1], 0, line_graph.edges()[1].to(torch.int64)) atoms = self.update_network_atom(node_feat) end_atom_index = torch.unsqueeze(end_atom_index, 1) - atoms = torch.squeeze(atoms[end_atom_index.long()]) + atoms = torch.squeeze(atoms[end_atom_index]) basis = three_basis * atoms three_cutoff = torch.unsqueeze(three_cutoff, dim=1) # type: ignore - weights = torch.reshape(three_cutoff[torch.stack(list(line_graph.edges()), dim=1)], (-1, 2)) # type: ignore - weights = torch.prod(weights, axis=-1) # type: ignore + weights = three_cutoff[torch.stack(list(line_graph.edges()), dim=1)].view(-1, 2) # type: ignore + weights = torch.prod(weights, dim=-1) # type: ignore basis = basis * weights[:, None] new_bonds = scatter_sum( basis.to(matgl.float_th), diff --git a/matgl/utils/maths.py b/matgl/utils/maths.py index fe61f45f..71991d2f 100644 --- a/matgl/utils/maths.py +++ b/matgl/utils/maths.py @@ -101,13 +101,11 @@ def get_segment_indices_from_n(ns): ns: torch.Tensor, the number of atoms/bonds array Returns: - object: - - Returns: segment indices tensor + torch.Tensor: segment indices tensor """ - B = ns - A = torch.arange(B.size(dim=0)) - return A.repeat_interleave(B, dim=0) + segments = torch.zeros(ns.sum(), dtype=matgl.int_th) + segments[ns.cumsum(0)[:-1]] = 1 + return segments.cumsum(0) def get_range_indices_from_n(ns): diff --git a/tests/ext/test_ase.py b/tests/ext/test_ase.py index da583602..b07e00a0 100644 --- a/tests/ext/test_ase.py +++ b/tests/ext/test_ase.py @@ -8,21 +8,19 @@ from pymatgen.io.ase import AseAtomsAdaptor from matgl import load_model -from matgl.apps.pes import Potential from matgl.ext.ase import Atoms2Graph, M3GNetCalculator, MolecularDynamics, Relaxer -from matgl.models import M3GNet def test_M3GNetCalculator(MoS): adaptor = AseAtomsAdaptor() s_ase = adaptor.get_atoms(MoS) # type: ignore - model = M3GNet(element_types=["Mo", "S"], is_intensive=False) - ff = Potential(model=model) + ff = load_model("M3GNet-MP-2021.2.8-PES") calc = M3GNetCalculator(potential=ff) s_ase.set_calculator(calc) assert [s_ase.get_potential_energy().size] == [1] assert list(s_ase.get_forces().shape) == [2, 3] assert list(s_ase.get_stress().shape) == [6] + np.testing.assert_allclose(s_ase.get_potential_energy(), -10.312888) def test_M3GNetCalculator_mol(AcAla3NHMe):