Skip to content

Commit

Permalink
test_remove_oxidation_states() assert struct_out is struct_specie
Browse files Browse the repository at this point in the history
refactor existing tests to use new return value
  • Loading branch information
janosh committed Jan 23, 2024
1 parent b06a0ad commit 185711c
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 32 deletions.
3 changes: 1 addition & 2 deletions pymatgen/analysis/structure_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,7 @@ def sulfide_type(structure):
Returns:
(str) sulfide/polysulfide or None if structure is a sulfate.
"""
structure = structure.copy()
structure.remove_oxidation_states()
structure = structure.copy().remove_oxidation_states()
sulphur = Element("S")
comp = structure.composition
if comp.is_element or sulphur not in comp:
Expand Down
20 changes: 8 additions & 12 deletions pymatgen/io/cp2k/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def atoi(t):
return [atoi(c) for c in re.split(r"_(\d+)", text)]


def get_unique_site_indices(structure: Structure | Molecule):
def get_unique_site_indices(struct: Structure | Molecule):
"""
Get unique site indices for a structure according to site properties. Whatever site-property
has the most unique values is used for indexing.
Expand All @@ -147,31 +147,27 @@ def get_unique_site_indices(structure: Structure | Molecule):
"aux_basis",
}

for site in structure:
for site in struct:
for sp in site.species:
oxi_states.append(getattr(sp, "oxi_state", 0))
spins.append(getattr(sp, "_properties", {}).get("spin", 0))

structure.add_site_property("oxi_state", oxi_states)
structure.add_site_property("spin", spins)
structure.remove_oxidation_states()
struct.add_site_property("oxi_state", oxi_states)
struct.add_site_property("spin", spins)
struct.remove_oxidation_states()
items = [
(
site.species_string,
*[
structure.site_properties[k][i]
for k in structure.site_properties
if k.lower() in parsable_site_properties
],
*[struct.site_properties[k][i] for k in struct.site_properties if k.lower() in parsable_site_properties],
)
for i, site in enumerate(structure)
for i, site in enumerate(struct)
]
unique_itms = list(set(items))
_sites: dict[tuple, list] = {u: [] for u in unique_itms}
for i, itm in enumerate(items):
_sites[itm].append(i)
sites = {}
nums = {s: 1 for s in structure.symbol_set}
nums = {s: 1 for s in struct.symbol_set}
for s in _sites:
sites[f"{s[0]}_{nums[s[0]]}"] = _sites[s]
nums[s[0]] += 1
Expand Down
14 changes: 6 additions & 8 deletions pymatgen/transformations/standard_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ def apply_transformation(self, structure):
Returns:
Non-oxidation state decorated Structure.
"""
struct = structure.copy()
struct.remove_oxidation_states()
return struct
return structure.copy().remove_oxidation_states()

@property
def inverse(self):
Expand Down Expand Up @@ -611,25 +609,25 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
num_atoms = sum(structure.composition.values())

for output in ewald_m.output_lists:
s_copy = struct.copy()
struct_copy = struct.copy()
# do deletions afterwards because they screw up the indices of the
# structure
del_indices = []
for manipulation in output[1]:
if manipulation[1] is None:
del_indices.append(manipulation[0])
else:
s_copy[manipulation[0]] = manipulation[1]
s_copy.remove_sites(del_indices)
struct_copy[manipulation[0]] = manipulation[1]
struct_copy.remove_sites(del_indices)

if self.no_oxi_states:
s_copy.remove_oxidation_states()
struct_copy.remove_oxidation_states()

self._all_structures.append(
{
"energy": output[0],
"energy_above_minimum": (output[0] - lowest_energy) / num_atoms,
"structure": s_copy.get_sorted_structure(),
"structure": struct_copy.get_sorted_structure(),
}
)

Expand Down
7 changes: 3 additions & 4 deletions tests/analysis/test_local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,10 +1226,9 @@ def test_weighted_cn_no_oxid(self):
3.3897, 3.2589, 3.1207, 3.1924, 3.1915, 3.1207, 3.2598, 3.3897,
]
# fmt: on
s = self.lifepo4.copy()
s.remove_oxidation_states()
for idx in range(len(s)):
cn_array.append(cnn.get_cn(s, idx, use_weights=True))
struct = self.lifepo4.copy().remove_oxidation_states()
for idx in range(len(struct)):
cn_array.append(cnn.get_cn(struct, idx, use_weights=True))

assert_allclose(expected_array, cn_array, 2)

Expand Down
9 changes: 5 additions & 4 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,10 +1099,11 @@ def test_remove_oxidation_states(self):
o_specie = Species("O", -2)
coords = [[0, 0, 0], [0.75, 0.5, 0.75]]
lattice = Lattice.cubic(10)
s_elem = Structure(lattice, [co_elem, o_elem], coords)
s_specie = Structure(lattice, [co_specie, o_specie], coords)
s_specie.remove_oxidation_states()
assert s_elem == s_specie, "Oxidation state remover failed"
struct_elem = Structure(lattice, [co_elem, o_elem], coords)
struct_specie = Structure(lattice, [co_specie, o_specie], coords)
struct_out = struct_specie.remove_oxidation_states()
assert struct_out is struct_specie
assert struct_elem == struct_specie, "Oxidation state remover failed"

def test_add_oxidation_states_by_guess(self):
struct = PymatgenTest.get_structure("Li2O")
Expand Down
3 changes: 1 addition & 2 deletions tests/transformations/test_advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,7 @@ def setUp(self):
trans = AutoOxiStateDecorationTransformation()
self.Fe3O4_oxi = trans.apply_transformation(self.Fe3O4)

self.Li8Fe2NiCoO8 = Structure.from_file(f"{TEST_FILES_DIR}/Li8Fe2NiCoO8.cif")
self.Li8Fe2NiCoO8.remove_oxidation_states()
self.Li8Fe2NiCoO8 = Structure.from_file(f"{TEST_FILES_DIR}/Li8Fe2NiCoO8.cif").remove_oxidation_states()

def test_apply_transformation(self):
trans = MagOrderingTransformation({"Fe": 5})
Expand Down

0 comments on commit 185711c

Please sign in to comment.