Skip to content

Commit

Permalink
skip failing test_callable_sort_criteria test with TODO and link to m…
Browse files Browse the repository at this point in the history
…atgl issue
  • Loading branch information
janosh committed Mar 7, 2024
1 parent 6a9d3c5 commit d8fdb09
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
20 changes: 10 additions & 10 deletions pymatgen/transformations/advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ def __init__(
if max_cell_size and max_disordered_sites:
raise ValueError("Cannot set both max_cell_size and max_disordered_sites!")

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
def apply_transformation(
self, structure: Structure, return_ranked_list: bool | int = False
) -> Structure | list[dict]:
"""Returns either a single ordered structure or a sequence of all ordered
structures.
Expand Down Expand Up @@ -879,7 +881,7 @@ def apply_transformation(
# remove dummy species and replace Spin.up or Spin.down
# with spin magnitudes given in mag_species_spin arg
alls = self._remove_dummy_species(alls)
alls = self._add_spin_magnitudes(alls)
alls = self._add_spin_magnitudes(alls) # type: ignore[arg-type]
else:
for idx in range(len(alls)):
alls[idx]["structure"] = self._remove_dummy_species(alls[idx]["structure"])
Expand All @@ -891,7 +893,7 @@ def apply_transformation(
num_to_return = 1

if num_to_return == 1 or not return_ranked_list:
return alls[0]["structure"] if num_to_return else alls
return alls[0]["structure"] if num_to_return else alls # type: ignore[return-value]

# remove duplicate structures and group according to energy model
matcher = StructureMatcher(comparator=SpinComparator())
Expand Down Expand Up @@ -1010,11 +1012,10 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
Args:
structure (Structure): Input structure to dope
return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures.
is returned. If False, only the single lowest energy structure is returned. Defaults to False.
Returns:
[{"structure": Structure, "energy": float}]
list[dict] | Structure: each dict has shape {"structure": Structure, "energy": float}.
"""
comp = structure.composition
logger.info(f"Composition: {comp}")
Expand Down Expand Up @@ -1059,7 +1060,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
logger.info(f"{lengths=}")
logger.info(f"{scaling=}")

all_structures = []
all_structures: list[dict] = []
trafo = EnumerateStructureTransformation(**self.kwargs)

for sp in compatible_species:
Expand Down Expand Up @@ -1131,10 +1132,9 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
}
)

ss = trafo.apply_transformation(supercell, return_ranked_list=self.max_structures_per_enum)
logger.info(f"{len(ss)} distinct structures")
all_structures.extend(ss)

structs = trafo.apply_transformation(supercell, return_ranked_list=self.max_structures_per_enum)
logger.info(f"{len(structs)} distinct structures")
all_structures.extend(structs)
logger.info(f"Total {len(all_structures)} doped structures")
if return_ranked_list:
return all_structures[:return_ranked_list]
Expand Down
5 changes: 3 additions & 2 deletions tests/transformations/test_advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def test_m3gnet(self):
# Check ordering of energy/atom
assert alls[0]["energy"] / alls[0]["num_sites"] <= alls[-1]["energy"] / alls[-1]["num_sites"]

@pytest.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved")
def test_callable_sort_criteria(self):
matgl = pytest.importorskip("matgl")
from matgl.ext.ase import Relaxer
Expand All @@ -212,8 +213,8 @@ def test_callable_sort_criteria(self):

m3gnet_model = Relaxer(potential=pot)

def sort_criteria(s):
relax_results = m3gnet_model.relax(s)
def sort_criteria(struct: Structure) -> tuple[Structure, float]:
relax_results = m3gnet_model.relax(struct)
energy = float(relax_results["trajectory"].energies[-1])
return relax_results["final_structure"], energy

Expand Down

0 comments on commit d8fdb09

Please sign in to comment.