Skip to content

Commit

Permalink
skip failing m3gnet tests with TODO and link to matgl issue
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Mar 7, 2024
1 parent 6a9d3c5 commit 4fcbc58
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
20 changes: 10 additions & 10 deletions pymatgen/transformations/advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ def __init__(
if max_cell_size and max_disordered_sites:
raise ValueError("Cannot set both max_cell_size and max_disordered_sites!")

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
def apply_transformation(
self, structure: Structure, return_ranked_list: bool | int = False
) -> Structure | list[dict]:
"""Returns either a single ordered structure or a sequence of all ordered
structures.
Expand Down Expand Up @@ -879,7 +881,7 @@ def apply_transformation(
# remove dummy species and replace Spin.up or Spin.down
# with spin magnitudes given in mag_species_spin arg
alls = self._remove_dummy_species(alls)
alls = self._add_spin_magnitudes(alls)
alls = self._add_spin_magnitudes(alls) # type: ignore[arg-type]
else:
for idx in range(len(alls)):
alls[idx]["structure"] = self._remove_dummy_species(alls[idx]["structure"])
Expand All @@ -891,7 +893,7 @@ def apply_transformation(
num_to_return = 1

if num_to_return == 1 or not return_ranked_list:
return alls[0]["structure"] if num_to_return else alls
return alls[0]["structure"] if num_to_return else alls # type: ignore[return-value]

# remove duplicate structures and group according to energy model
matcher = StructureMatcher(comparator=SpinComparator())
Expand Down Expand Up @@ -1010,11 +1012,10 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
Args:
structure (Structure): Input structure to dope
return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures.
is returned. If False, only the single lowest energy structure is returned. Defaults to False.
Returns:
[{"structure": Structure, "energy": float}]
list[dict] | Structure: each dict has shape {"structure": Structure, "energy": float}.
"""
comp = structure.composition
logger.info(f"Composition: {comp}")
Expand Down Expand Up @@ -1059,7 +1060,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
logger.info(f"{lengths=}")
logger.info(f"{scaling=}")

all_structures = []
all_structures: list[dict] = []
trafo = EnumerateStructureTransformation(**self.kwargs)

for sp in compatible_species:
Expand Down Expand Up @@ -1131,10 +1132,9 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
}
)

ss = trafo.apply_transformation(supercell, return_ranked_list=self.max_structures_per_enum)
logger.info(f"{len(ss)} distinct structures")
all_structures.extend(ss)

structs = trafo.apply_transformation(supercell, return_ranked_list=self.max_structures_per_enum)
logger.info(f"{len(structs)} distinct structures")
all_structures.extend(structs)
logger.info(f"Total {len(all_structures)} doped structures")
if return_ranked_list:
return all_structures[:return_ranked_list]
Expand Down
4 changes: 4 additions & 0 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,6 +1705,7 @@ def test_relax_ase_opt_kwargs(self):
assert traj[0] != traj[-1]
assert os.path.isfile(traj_file)

@pytest.mark.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved")
def test_calculate_m3gnet(self):
pytest.importorskip("matgl")
calculator = self.get_structure("Si").calculate()
Expand All @@ -1716,6 +1717,7 @@ def test_calculate_m3gnet(self):
assert np.linalg.norm(calculator.results["forces"]) == approx(7.8123485e-06, abs=0.2)
assert np.linalg.norm(calculator.results["stress"]) == approx(1.7861567, abs=2)

@pytest.mark.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved")
def test_relax_m3gnet(self):
pytest.importorskip("matgl")
struct = self.get_structure("Si")
Expand All @@ -1726,6 +1728,7 @@ def test_relax_m3gnet(self):
actual = relaxed.dynamics[key]
assert actual == val, f"expected {key} to be {val}, {actual=}"

@pytest.mark.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved")
def test_relax_m3gnet_fixed_lattice(self):
pytest.importorskip("matgl")
struct = self.get_structure("Si")
Expand All @@ -1734,6 +1737,7 @@ def test_relax_m3gnet_fixed_lattice(self):
assert hasattr(relaxed, "calc")
assert relaxed.dynamics["optimizer"] == "BFGS"

@pytest.mark.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved")
def test_relax_m3gnet_with_traj(self):
pytest.importorskip("matgl")
struct = self.get_structure("Si")
Expand Down
6 changes: 4 additions & 2 deletions tests/transformations/test_advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_apply_transformation(self):
for s in alls:
assert "energy" not in s

@pytest.mark.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved")
def test_m3gnet(self):
pytest.importorskip("matgl")
enum_trans = EnumerateStructureTransformation(refine_structure=True, sort_criteria="m3gnet_relax")
Expand All @@ -204,6 +205,7 @@ def test_m3gnet(self):
# Check ordering of energy/atom
assert alls[0]["energy"] / alls[0]["num_sites"] <= alls[-1]["energy"] / alls[-1]["num_sites"]

@pytest.mark.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved")
def test_callable_sort_criteria(self):
matgl = pytest.importorskip("matgl")
from matgl.ext.ase import Relaxer
Expand All @@ -212,8 +214,8 @@ def test_callable_sort_criteria(self):

m3gnet_model = Relaxer(potential=pot)

def sort_criteria(s):
relax_results = m3gnet_model.relax(s)
def sort_criteria(struct: Structure) -> tuple[Structure, float]:
relax_results = m3gnet_model.relax(struct)
energy = float(relax_results["trajectory"].energies[-1])
return relax_results["final_structure"], energy

Expand Down

0 comments on commit 4fcbc58

Please sign in to comment.