Skip to content

Commit

Permalink
Return self in SiteCollection spin/oxi state add/remove methods (#…
Browse files Browse the repository at this point in the history
…3573)

* return self in SiteCollection.remove_oxidation_states()

* test_remove_oxidation_states() assert struct_out is struct_specie

refactor existing tests to use new return value

* same for add_oxidation_state_by_guess(), add_spin_by_element(), add_spin_by_site(), remove_spin()

* add tests

* doc str tweaks
  • Loading branch information
janosh authored Jan 23, 2024
1 parent f9e2830 commit b0e4eb2
Show file tree
Hide file tree
Showing 15 changed files with 77 additions and 73 deletions.
2 changes: 1 addition & 1 deletion pymatgen/analysis/local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ValenceIonicRadiusEvaluator:
def __init__(self, structure: Structure) -> None:
"""
Args:
structure: pymatgen.core.structure.Structure.
structure: pymatgen Structure.
"""
self._structure = structure.copy()
self._valences = self._get_valences()
Expand Down
3 changes: 1 addition & 2 deletions pymatgen/analysis/structure_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,7 @@ def sulfide_type(structure):
Returns:
(str) sulfide/polysulfide or None if structure is a sulfate.
"""
structure = structure.copy()
structure.remove_oxidation_states()
structure = structure.copy().remove_oxidation_states()
sulphur = Element("S")
comp = structure.composition
if comp.is_element or sulphur not in comp:
Expand Down
16 changes: 8 additions & 8 deletions pymatgen/command_line/gulp_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def specie_potential_lines(structure, potential, **kwargs):
structure.
Args:
structure: pymatgen.core.structure.Structure object
structure: pymatgen Structure object
potential: String specifying the type of potential used
kwargs: Additional parameters related to potential. For
potential == "buckingham",
Expand Down Expand Up @@ -375,7 +375,7 @@ def buckingham_input(self, structure: Structure, keywords, library=None, uc=True
from library.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
keywords: GULP first line keywords.
library (Default=None): File containing the species and potential.
uc (Default=True): Unit Cell Flag.
Expand All @@ -401,7 +401,7 @@ def buckingham_potential(structure, val_dict=None):
J. Mater Chem., 4, 831-837 (1994)
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
val_dict (Needed if structure is not charge neutral): {El:valence}
dict, where El is element.
"""
Expand Down Expand Up @@ -462,7 +462,7 @@ def tersoff_input(self, structure: Structure, periodic=False, uc=True, *keywords
"""Gets a GULP input with Tersoff potential for an oxide structure.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
periodic (Default=False): Flag denoting whether periodic
boundary conditions are used
library (Default=None): File containing the species and potential.
Expand All @@ -487,7 +487,7 @@ def tersoff_potential(structure):
"""Generate the species, Tersoff potential lines for an oxide structure.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
"""
bv = BVAnalyzer()
el = [site.specie.symbol for site in structure]
Expand Down Expand Up @@ -702,7 +702,7 @@ def get_energy_tersoff(structure, gulp_cmd="gulp"):
"""Compute the energy of a structure using Tersoff potential.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
gulp_cmd: GULP command if not in standard place
"""
gio = GulpIO()
Expand All @@ -716,7 +716,7 @@ def get_energy_buckingham(structure, gulp_cmd="gulp", keywords=("optimise", "con
"""Compute the energy of a structure using Buckingham potential.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
gulp_cmd: GULP command if not in standard place
keywords: GULP first line keywords
valence_dict: {El: valence}. Needed if the structure is not charge
Expand All @@ -733,7 +733,7 @@ def get_energy_relax_structure_buckingham(structure, gulp_cmd="gulp", keywords=(
"""Relax a structure and compute the energy using Buckingham potential.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
gulp_cmd: GULP command if not in standard place
keywords: GULP first line keywords
valence_dict: {El: valence}. Needed if the structure is not charge
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class Interface(Structure):
"""This class stores data for defining an interface between two structures.
It is a subclass of pymatgen.core.structure.Structure.
It is a subclass of pymatgen Structure.
"""

def __init__(
Expand Down
20 changes: 15 additions & 5 deletions pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def add_oxidation_state_by_site(self, oxidation_states: list[float]) -> None:
new_sp[Species(sym, ox)] = occu
site.species = Composition(new_sp)

def remove_oxidation_states(self) -> None:
def remove_oxidation_states(self) -> SiteCollection:
"""Removes oxidation states from a structure."""
for site in self:
new_sp: dict[Element, float] = collections.defaultdict(float)
Expand All @@ -587,7 +587,9 @@ def remove_oxidation_states(self) -> None:
new_sp[Element(sym)] += occu
site.species = Composition(new_sp)

def add_oxidation_state_by_guess(self, **kwargs) -> None:
return self

def add_oxidation_state_by_guess(self, **kwargs) -> SiteCollection:
"""Decorates the structure with oxidation state, guessing
using Composition.oxi_state_guesses(). If multiple guesses are found
we take the first one.
Expand All @@ -599,7 +601,9 @@ def add_oxidation_state_by_guess(self, **kwargs) -> None:
oxi_guess = oxi_guess or [{e.symbol: 0 for e in self.composition}]
self.add_oxidation_state_by_element(oxi_guess[0])

def add_spin_by_element(self, spins: dict[str, float]) -> None:
return self

def add_spin_by_element(self, spins: dict[str, float]) -> SiteCollection:
"""Add spin states to structure.
Args:
Expand All @@ -615,7 +619,9 @@ def add_spin_by_element(self, spins: dict[str, float]) -> None:
new_species[species] = occu
site.species = Composition(new_species)

def add_spin_by_site(self, spins: Sequence[float]) -> None:
return self

def add_spin_by_site(self, spins: Sequence[float]) -> SiteCollection:
"""Add spin states to structure by site.
Args:
Expand All @@ -632,7 +638,9 @@ def add_spin_by_site(self, spins: Sequence[float]) -> None:
new_species[Species(sym, oxidation_state=oxi_state, spin=spin)] = occu
site.species = Composition(new_species)

def remove_spin(self) -> None:
return self

def remove_spin(self) -> SiteCollection:
"""Remove spin states from structure."""
for site in self:
new_sp: dict[Element, float] = collections.defaultdict(float)
Expand All @@ -641,6 +649,8 @@ def remove_spin(self) -> None:
new_sp[Species(sp.symbol, oxidation_state=oxi_state)] += occu
site.species = Composition(new_sp)

return self

def extract_cluster(self, target_sites: list[Site], **kwargs) -> list[Site]:
"""Extracts a cluster of atoms based on bond lengths.
Expand Down
10 changes: 4 additions & 6 deletions pymatgen/ext/cod.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def query(self, sql: str) -> str:
Returns:
Response from SQL query.
"""
r = subprocess.check_output(["mysql", "-u", "cod_reader", "-h", self.url, "-e", sql, "cod"])
return r.decode("utf-8")
resp = subprocess.check_output(["mysql", "-u", "cod_reader", "-h", self.url, "-e", sql, "cod"])
return resp.decode("utf-8")

@requires(which("mysql"), "mysql must be installed to use this query.")
def get_cod_ids(self, formula):
Expand Down Expand Up @@ -84,8 +84,7 @@ def get_structure_by_id(self, cod_id, **kwargs):
Args:
cod_id (int): COD id.
kwargs: All kwargs supported by
:func:`pymatgen.core.structure.Structure.from_str`.
kwargs: All kwargs supported by Structure.from_str.
Returns:
A Structure.
Expand All @@ -100,8 +99,7 @@ def get_structure_by_formula(self, formula: str, **kwargs) -> list[dict[str, str
Args:
formula (str): Chemical formula.
kwargs: All kwargs supported by
:func:`pymatgen.core.structure.Structure.from_str`.
kwargs: All kwargs supported by Structure.from_str.
Returns:
A list of dict of the format [{"structure": Structure, "cod_id": int, "sg": "P n m a"}]
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/io/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs)
**cls_kwargs: Any additional kwargs to pass to the cls
Returns:
Equivalent pymatgen.core.structure.Structure
Structure: Equivalent pymatgen Structure
"""
symbols = atoms.get_chemical_symbols()
positions = atoms.get_positions()
Expand Down
20 changes: 8 additions & 12 deletions pymatgen/io/cp2k/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def atoi(t):
return [atoi(c) for c in re.split(r"_(\d+)", text)]


def get_unique_site_indices(structure: Structure | Molecule):
def get_unique_site_indices(struct: Structure | Molecule):
"""
Get unique site indices for a structure according to site properties. Whatever site-property
has the most unique values is used for indexing.
Expand All @@ -147,31 +147,27 @@ def get_unique_site_indices(structure: Structure | Molecule):
"aux_basis",
}

for site in structure:
for site in struct:
for sp in site.species:
oxi_states.append(getattr(sp, "oxi_state", 0))
spins.append(getattr(sp, "_properties", {}).get("spin", 0))

structure.add_site_property("oxi_state", oxi_states)
structure.add_site_property("spin", spins)
structure.remove_oxidation_states()
struct.add_site_property("oxi_state", oxi_states)
struct.add_site_property("spin", spins)
struct.remove_oxidation_states()
items = [
(
site.species_string,
*[
structure.site_properties[k][i]
for k in structure.site_properties
if k.lower() in parsable_site_properties
],
*[struct.site_properties[k][i] for k in struct.site_properties if k.lower() in parsable_site_properties],
)
for i, site in enumerate(structure)
for i, site in enumerate(struct)
]
unique_itms = list(set(items))
_sites: dict[tuple, list] = {u: [] for u in unique_itms}
for i, itm in enumerate(items):
_sites[itm].append(i)
sites = {}
nums = {s: 1 for s in structure.symbol_set}
nums = {s: 1 for s in struct.symbol_set}
for s in _sites:
sites[f"{s[0]}_{nums[s[0]]}"] = _sites[s]
nums[s[0]] += 1
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/io/jarvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_atoms(structure):
Returns JARVIS Atoms object from pymatgen structure.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
Returns:
JARVIS Atoms object
Expand All @@ -49,7 +49,7 @@ def get_structure(atoms):
atoms: JARVIS Atoms object
Returns:
Equivalent pymatgen.core.structure.Structure
Equivalent pymatgen Structure
"""
return Structure(
lattice=atoms.lattice_mat, species=atoms.elements, coords=atoms.frac_coords, coords_are_cartesian=False
Expand Down
21 changes: 10 additions & 11 deletions pymatgen/io/zeopp.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def get_voronoi_nodes(structure, rad_dict=None, probe_rad=0.1):
Calls Zeo++ for Voronoi decomposition.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
rad_dict (optional): Dictionary of radii of elements in structure.
If not given, Zeo++ default values are used.
Note: Zeo++ uses atomic radii of elements.
Expand All @@ -228,10 +228,9 @@ def get_voronoi_nodes(structure, rad_dict=None, probe_rad=0.1):
0.1 A
Returns:
voronoi nodes as pymatgen.core.structure.Structure within the
unit cell defined by the lattice of input structure
voronoi face centers as pymatgen.core.structure.Structure within the
unit cell defined by the lattice of input structure
voronoi nodes as pymatgen Structure within the unit cell defined by the lattice of
input structure voronoi face centers as pymatgen Structure within the unit cell
defined by the lattice of input structure
"""
with ScratchDir("."):
name = "temp_zeo1"
Expand Down Expand Up @@ -306,7 +305,7 @@ def get_high_accuracy_voronoi_nodes(structure, rad_dict, probe_rad=0.1):
Calls Zeo++ for Voronoi decomposition.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
rad_dict (optional): Dictionary of radii of elements in structure.
If not given, Zeo++ default values are used.
Note: Zeo++ uses atomic radii of elements.
Expand All @@ -315,9 +314,9 @@ def get_high_accuracy_voronoi_nodes(structure, rad_dict, probe_rad=0.1):
Default is 0.1 A
Returns:
voronoi nodes as pymatgen.core.structure.Structure within the
voronoi nodes as pymatgen Structure within the
unit cell defined by the lattice of input structure
voronoi face centers as pymatgen.core.structure.Structure within the
voronoi face centers as pymatgen Structure within the
unit cell defined by the lattice of input structure
"""
with ScratchDir("."):
Expand Down Expand Up @@ -368,7 +367,7 @@ def get_free_sphere_params(structure, rad_dict=None, probe_rad=0.1):
Calls Zeo++ for Voronoi decomposition.
Args:
structure: pymatgen.core.structure.Structure
structure: pymatgen Structure
rad_dict (optional): Dictionary of radii of elements in structure.
If not given, Zeo++ default values are used.
Note: Zeo++ uses atomic radii of elements.
Expand All @@ -377,9 +376,9 @@ def get_free_sphere_params(structure, rad_dict=None, probe_rad=0.1):
0.1 A
Returns:
voronoi nodes as pymatgen.core.structure.Structure within the
voronoi nodes as pymatgen Structure within the
unit cell defined by the lattice of input structure
voronoi face centers as pymatgen.core.structure.Structure within the
voronoi face centers as pymatgen Structure within the
unit cell defined by the lattice of input structure
"""
with ScratchDir("."):
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/symmetry/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _get_symmetry_dataset(cell, symprec, angle_tolerance):


class SpacegroupAnalyzer:
"""Takes a pymatgen.core.structure.Structure object and a symprec.
"""Takes a pymatgen Structure object and a symprec.
Uses spglib to perform various symmetry finding operations.
"""
Expand Down
14 changes: 6 additions & 8 deletions pymatgen/transformations/standard_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ def apply_transformation(self, structure):
Returns:
Non-oxidation state decorated Structure.
"""
struct = structure.copy()
struct.remove_oxidation_states()
return struct
return structure.copy().remove_oxidation_states()

@property
def inverse(self):
Expand Down Expand Up @@ -611,25 +609,25 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
num_atoms = sum(structure.composition.values())

for output in ewald_m.output_lists:
s_copy = struct.copy()
struct_copy = struct.copy()
# do deletions afterwards because they screw up the indices of the
# structure
del_indices = []
for manipulation in output[1]:
if manipulation[1] is None:
del_indices.append(manipulation[0])
else:
s_copy[manipulation[0]] = manipulation[1]
s_copy.remove_sites(del_indices)
struct_copy[manipulation[0]] = manipulation[1]
struct_copy.remove_sites(del_indices)

if self.no_oxi_states:
s_copy.remove_oxidation_states()
struct_copy.remove_oxidation_states()

self._all_structures.append(
{
"energy": output[0],
"energy_above_minimum": (output[0] - lowest_energy) / num_atoms,
"structure": s_copy.get_sorted_structure(),
"structure": struct_copy.get_sorted_structure(),
}
)

Expand Down
7 changes: 3 additions & 4 deletions tests/analysis/test_local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,10 +1226,9 @@ def test_weighted_cn_no_oxid(self):
3.3897, 3.2589, 3.1207, 3.1924, 3.1915, 3.1207, 3.2598, 3.3897,
]
# fmt: on
s = self.lifepo4.copy()
s.remove_oxidation_states()
for idx in range(len(s)):
cn_array.append(cnn.get_cn(s, idx, use_weights=True))
struct = self.lifepo4.copy().remove_oxidation_states()
for idx in range(len(struct)):
cn_array.append(cnn.get_cn(struct, idx, use_weights=True))

assert_allclose(expected_array, cn_array, 2)

Expand Down
Loading

0 comments on commit b0e4eb2

Please sign in to comment.