diff --git a/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py b/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py index 47833b7b459..f202bec079a 100644 --- a/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py +++ b/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py @@ -218,7 +218,7 @@ def points_wcs_csc(self, permutation=None): """ if permutation is None: return self._points_wcs_csc - return np.concatenate((self._points_wcs_csc[0:1], self._points_wocs_csc.take(permutation, axis=0))) + return np.concatenate((self._points_wcs_csc[:1], self._points_wocs_csc.take(permutation, axis=0))) def points_wocs_csc(self, permutation=None): """ @@ -238,7 +238,7 @@ def points_wcs_ctwcc(self, permutation=None): return self._points_wcs_ctwcc return np.concatenate( ( - self._points_wcs_ctwcc[0:1], + self._points_wcs_ctwcc[:1], self._points_wocs_ctwcc.take(permutation, axis=0), ) ) @@ -261,7 +261,7 @@ def points_wcs_ctwocc(self, permutation=None): return self._points_wcs_ctwocc return np.concatenate( ( - self._points_wcs_ctwocc[0:1], + self._points_wcs_ctwocc[:1], self._points_wocs_ctwocc.take(permutation, axis=0), ) ) diff --git a/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py b/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py index 51a47a7bdd8..edddcf9166e 100644 --- a/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py +++ b/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py @@ -901,7 +901,7 @@ def project_and_to2dim(self, pps, plane_center): xypps = [] for pp in proj: xyzpp = np.dot(pp, PP) - xypps.append(xyzpp[0:2]) + xypps.append(xyzpp[:2]) if str(plane_center) == "mean": mean = np.zeros(2, float) for pp in xypps: @@ -910,7 +910,7 @@ def project_and_to2dim(self, pps, plane_center): xypps = [pp - mean for pp in xypps] elif plane_center is not None: projected_plane_center = self.projectionpoints([plane_center])[0] - xy_projected_plane_center = np.dot(projected_plane_center, PP)[0:2] + xy_projected_plane_center = np.dot(projected_plane_center, PP)[:2] xypps = [pp - xy_projected_plane_center for pp in xypps] return xypps @@ -960,7 +960,7 @@ def coefficients(self): @property def abcd(self): """A tuple with the plane coefficients.""" - return tuple(self._coefficients[0:4]) + return tuple(self._coefficients[:4]) @property def a(self): diff --git a/pymatgen/analysis/ferroelectricity/polarization.py b/pymatgen/analysis/ferroelectricity/polarization.py index 4f934495be9..08df27786b6 100644 --- a/pymatgen/analysis/ferroelectricity/polarization.py +++ b/pymatgen/analysis/ferroelectricity/polarization.py @@ -130,7 +130,7 @@ def get_nearest_site(struct: Structure, coords: Sequence[float], site: PeriodicS # Sort by distance to coords ns.sort(key=lambda x: x[1]) # Return PeriodicSite and distance of closest image - return ns[0][0:2] + return ns[0][:2] class Polarization: diff --git a/pymatgen/analysis/magnetism/analyzer.py b/pymatgen/analysis/magnetism/analyzer.py index 5ee0fe417f6..e05dfe07a2d 100644 --- a/pymatgen/analysis/magnetism/analyzer.py +++ b/pymatgen/analysis/magnetism/analyzer.py @@ -1011,7 +1011,7 @@ def _add_structures(ordered_structures, ordered_structures_origins, structures_t # ...and decide which ones to keep if len(max_symmetries) > self.truncate_by_symmetry: - max_symmetries = max_symmetries[0:5] + max_symmetries = max_symmetries[:5] structs_to_keep = [(idx, num) for idx, num in enumerate(num_sym_ops) if num in max_symmetries] # sort so that highest symmetry structs are first diff --git a/pymatgen/command_line/gulp_caller.py b/pymatgen/command_line/gulp_caller.py index 9c43b671dc9..452631492b0 100644 --- a/pymatgen/command_line/gulp_caller.py +++ b/pymatgen/command_line/gulp_caller.py @@ -585,7 +585,7 @@ def get_relaxed_structure(gout: str): # read the site coordinates in the following lines idx += 6 line = output_lines[idx] - while line[0:2] != "--": + while line[:2] != "--": structure_lines.append(line) idx += 1 line = output_lines[idx] diff --git a/pymatgen/command_line/mcsqs_caller.py b/pymatgen/command_line/mcsqs_caller.py index 40be0b12789..657060d6dec 100644 --- a/pymatgen/command_line/mcsqs_caller.py +++ b/pymatgen/command_line/mcsqs_caller.py @@ -261,7 +261,7 @@ def _parse_clusters(filename): for point in range(cluster_dict["num_points_in_cluster"]): line = cluster[3 + point].split(" ") point_dict = {} - point_dict["coordinates"] = [float(line) for line in line[0:3]] + point_dict["coordinates"] = [float(line) for line in line[:3]] point_dict["num_possible_species"] = int(line[3]) + 2 # see ATAT manual for why +2 point_dict["cluster_function"] = float(line[4]) # see ATAT manual for what "function" is points.append(point_dict) diff --git a/pymatgen/core/composition.py b/pymatgen/core/composition.py index 531cbb3de5e..5c589936a92 100644 --- a/pymatgen/core/composition.py +++ b/pymatgen/core/composition.py @@ -181,7 +181,7 @@ def __eq__(self, other: object) -> bool: Args: other: Composition to compare to. """ - if not isinstance(other, (Composition, dict)): + if not isinstance(other, (type(self), dict)): return NotImplemented # elements with amounts < Composition.amount_tolerance don't show up @@ -190,7 +190,7 @@ def __eq__(self, other: object) -> bool: if len(self) != len(other): return False - return all(abs(amt - other[el]) <= Composition.amount_tolerance for el, amt in self.items()) + return all(abs(amt - other[el]) <= type(self).amount_tolerance for el, amt in self.items()) def __ge__(self, other: object) -> bool: """Composition greater than or equal to. We consider compositions A >= B @@ -200,31 +200,31 @@ def __ge__(self, other: object) -> bool: Should ONLY be used for defining a sort order (the behavior is probably not what you'd expect). """ - if not isinstance(other, Composition): + if not isinstance(other, type(self)): return NotImplemented for el in sorted(set(self.elements + other.elements)): - if other[el] - self[el] >= Composition.amount_tolerance: + if other[el] - self[el] >= type(self).amount_tolerance: return False # TODO @janosh 2024-04-29: is this a bug? why would we return True early? - if self[el] - other[el] >= Composition.amount_tolerance: + if self[el] - other[el] >= type(self).amount_tolerance: return True return True - def __add__(self, other: object) -> Composition: + def __add__(self, other: object) -> Self: """Add two compositions. For example, an Fe2O3 composition + an FeO composition gives a Fe3O4 composition. """ - if not isinstance(other, (Composition, dict)): + if not isinstance(other, (type(self), dict)): return NotImplemented new_el_map: dict[SpeciesLike, float] = defaultdict(float) new_el_map.update(self) for key, val in other.items(): new_el_map[get_el_sp(key)] += val - return Composition(new_el_map, allow_negative=self.allow_negative) + return type(self)(new_el_map, allow_negative=self.allow_negative) - def __sub__(self, other: object) -> Composition: + def __sub__(self, other: object) -> Self: """Subtracts two compositions. For example, an Fe2O3 composition - an FeO composition gives an FeO2 composition. @@ -233,29 +233,29 @@ def __sub__(self, other: object) -> Composition: original composition in any of its elements, unless allow_negative is True """ - if not isinstance(other, (Composition, dict)): + if not isinstance(other, (type(self), dict)): return NotImplemented new_el_map: dict[SpeciesLike, float] = defaultdict(float) new_el_map.update(self) for key, val in other.items(): new_el_map[get_el_sp(key)] -= val - return Composition(new_el_map, allow_negative=self.allow_negative) + return type(self)(new_el_map, allow_negative=self.allow_negative) - def __mul__(self, other: object) -> Composition: + def __mul__(self, other: object) -> Self: """Multiply a Composition by an integer or a float. Fe2O3 * 4 -> Fe8O12. """ if not isinstance(other, (int, float)): return NotImplemented - return Composition({el: self[el] * other for el in self}, allow_negative=self.allow_negative) + return type(self)({el: self[el] * other for el in self}, allow_negative=self.allow_negative) __rmul__ = __mul__ - def __truediv__(self, other: object) -> Composition: + def __truediv__(self, other: object) -> Self: if not isinstance(other, (int, float)): return NotImplemented - return Composition({el: self[el] / other for el in self}, allow_negative=self.allow_negative) + return type(self)({el: self[el] / other for el in self}, allow_negative=self.allow_negative) __div__ = __truediv__ @@ -273,7 +273,12 @@ def total_electrons(self) -> float: """Total number of electrons in composition.""" return sum((el.Z * abs(amt) for el, amt in self.items())) - def almost_equals(self, other: Composition, rtol: float = 0.1, atol: float = 1e-8) -> bool: + def almost_equals( + self, + other: Composition, + rtol: float = 0.1, + atol: float = 1e-8, + ) -> bool: """Get true if compositions are equal within a tolerance. Args: @@ -295,7 +300,7 @@ def is_element(self) -> bool: """True if composition is an element.""" return len(self) == 1 - def copy(self) -> Composition: + def copy(self) -> Self: """A copy of the composition.""" return Composition(self, allow_negative=self.allow_negative) @@ -332,12 +337,12 @@ def iupac_formula(self) -> str: return " ".join(formula) @property - def element_composition(self) -> Composition: + def element_composition(self) -> Self: """The composition replacing any species by the corresponding element.""" return Composition(self.get_el_amt_dict(), allow_negative=self.allow_negative) @property - def fractional_composition(self) -> Composition: + def fractional_composition(self) -> Self: """The normalized composition in which the amounts of each species sum to 1. E.g. "Fe2 O3".fractional_composition = "Fe0.4 O0.6". @@ -345,13 +350,13 @@ def fractional_composition(self) -> Composition: return self / self._n_atoms @property - def reduced_composition(self) -> Composition: + def reduced_composition(self) -> Self: """The reduced composition, i.e. amounts normalized by greatest common denominator. E.g. "Fe4 P4 O16".reduced_composition = "Fe P O4". """ return self.get_reduced_composition_and_factor()[0] - def get_reduced_composition_and_factor(self) -> tuple[Composition, float]: + def get_reduced_composition_and_factor(self) -> tuple[Self, float]: """Calculate a reduced composition and factor. Returns: @@ -378,14 +383,14 @@ def get_reduced_formula_and_factor(self, iupac_ordering: bool = False) -> tuple[ A pretty normalized formula and a multiplicative factor, i.e., Li4Fe4P4O16 returns (LiFePO4, 4). """ - all_int = all(abs(val - round(val)) < Composition.amount_tolerance for val in self.values()) + all_int = all(abs(val - round(val)) < type(self).amount_tolerance for val in self.values()) if not all_int: return self.formula.replace(" ", ""), 1 el_amt_dict = {key: int(round(val)) for key, val in self.get_el_amt_dict().items()} formula, factor = reduce_formula(el_amt_dict, iupac_ordering=iupac_ordering) - if formula in Composition.special_formulas: - formula = Composition.special_formulas[formula] + if formula in type(self).special_formulas: + formula = type(self).special_formulas[formula] factor /= 2 return formula, factor @@ -416,8 +421,8 @@ def get_integer_formula_and_factor( dct = {key: round(val / _gcd) for key, val in el_amt.items()} formula, factor = reduce_formula(dct, iupac_ordering=iupac_ordering) - if formula in Composition.special_formulas: - formula = Composition.special_formulas[formula] + if formula in type(self).special_formulas: + formula = type(self).special_formulas[formula] factor /= 2 return formula, factor * _gcd @@ -549,13 +554,13 @@ def _parse_formula(self, formula: str, strict: bool = True) -> dict[str, float]: def get_sym_dict(form: str, factor: float) -> dict[str, float]: sym_dict: dict[str, float] = defaultdict(float) - for m in re.finditer(r"([A-Z][a-z]*)\s*([-*\.e\d]*)", form): - el = m.group(1) + for match in re.finditer(r"([A-Z][a-z]*)\s*([-*\.e\d]*)", form): + el = match[1] amt = 1.0 - if m.group(2).strip() != "": - amt = float(m.group(2)) + if match[2].strip() != "": + amt = float(match[2]) sym_dict[el] += amt * factor - form = form.replace(m.group(), "", 1) + form = form.replace(match.group(), "", 1) if form.strip(): raise ValueError(f"{form} is an invalid formula!") return sym_dict @@ -563,9 +568,9 @@ def get_sym_dict(form: str, factor: float) -> dict[str, float]: match = re.search(r"\(([^\(\)]+)\)\s*([\.e\d]*)", formula) while match: factor = 1.0 - if match.group(2) != "": - factor = float(match.group(2)) - unit_sym_dict = get_sym_dict(match.group(1), factor) + if match[2] != "": + factor = float(match[2]) + unit_sym_dict = get_sym_dict(match[1], factor) expanded_sym = "".join(f"{el}{amt}" for el, amt in unit_sym_dict.items()) expanded_formula = formula.replace(match.group(), expanded_sym, 1) formula = expanded_formula @@ -734,7 +739,7 @@ def charge_balanced(self) -> bool | None: # to get a zero charge was found, so the composition is not charge balanced return False return None - return abs(self.charge) < Composition.charge_balanced_tolerance + return abs(self.charge) < type(self).charge_balanced_tolerance def oxi_state_guesses( self, @@ -778,7 +783,7 @@ def oxi_state_guesses( return ({self.elements[0].symbol: 0.0},) return self._get_oxi_state_guesses(all_oxi_states, max_sites, oxi_states_override, target_charge)[0] - def replace(self, elem_map: dict[str, str | dict[str, float]]) -> Composition: + def replace(self, elem_map: dict[str, str | dict[str, float]]) -> Self: """Replace elements in a composition. Returns a new Composition, leaving the old one unchanged. Args: @@ -827,7 +832,7 @@ def replace(self, elem_map: dict[str, str | dict[str, float]]) -> Composition: "This can be ambiguous, so be sure to check your result." ) - return Composition(new_comp) + return type(self)(new_comp) def add_charges_from_oxi_state_guesses( self, @@ -835,7 +840,7 @@ def add_charges_from_oxi_state_guesses( target_charge: float = 0, all_oxi_states: bool = False, max_sites: int | None = None, - ) -> Composition: + ) -> Self: """Assign oxidation states based on guessed oxidation states. See `oxi_state_guesses` for an explanation of how oxidation states are @@ -870,7 +875,7 @@ def add_charges_from_oxi_state_guesses( # Special case: No charged compound is possible if not oxidation_states: - return Composition({Species(e, 0): f for e, f in self.items()}) + return type(self)({Species(e, 0): f for e, f in self.items()}) # Generate the species species = [] @@ -878,9 +883,9 @@ def add_charges_from_oxi_state_guesses( species.extend([Species(el, c) for c in charges]) # Return the new object - return Composition(collections.Counter(species)) + return type(self)(collections.Counter(species)) - def remove_charges(self) -> Composition: + def remove_charges(self) -> Self: """Get a new Composition with charges from each Species removed. Returns: @@ -890,9 +895,15 @@ def remove_charges(self) -> Composition: dct: dict[Element, float] = defaultdict(float) for specie, amt in self.items(): dct[Element(specie.symbol)] += amt - return Composition(dct) + return type(self)(dct) - def _get_oxi_state_guesses(self, all_oxi_states, max_sites, oxi_states_override, target_charge): + def _get_oxi_state_guesses( + self, + all_oxi_states: bool, + max_sites: int | None, + oxi_states_override: dict[str, list] | None, + target_charge: float, + ) -> tuple[tuple, tuple]: """Utility operation for guessing oxidation states. See `oxi_state_guesses` for full details. This operation does the @@ -901,7 +912,7 @@ def _get_oxi_state_guesses(self, all_oxi_states, max_sites, oxi_states_override, Args: oxi_states_override (dict): dict of str->list to override an element's common oxidation states, e.g. {"V": [2,3,4,5]} - target_charge (int): the desired total charge on the structure. + target_charge (float): the desired total charge on the structure. Default is 0 signifying charge balance. all_oxi_states (bool): if True, an element defaults to all oxidation states in pymatgen Element.icsd_oxidation_states. @@ -941,7 +952,7 @@ def _get_oxi_state_guesses(self, all_oxi_states, max_sites, oxi_states_override, # Load prior probabilities of oxidation states, used to rank solutions if not Composition.oxi_prob: all_data = loadfn(f"{module_dir}/../analysis/icsd_bv.yaml") - Composition.oxi_prob = {Species.from_str(sp): data for sp, data in all_data["occurrence"].items()} + type(self).oxi_prob = {Species.from_str(sp): data for sp, data in all_data["occurrence"].items()} oxi_states_override = oxi_states_override or {} # assert: Composition only has integer amounts if not all(amt == int(amt) for amt in comp.values()): @@ -951,15 +962,15 @@ def _get_oxi_state_guesses(self, all_oxi_states, max_sites, oxi_states_override, # (taking into account nsites for that particular element) el_amt = comp.get_el_amt_dict() elements = list(el_amt) - el_sums = [] # matrix: dim1= el_idx, dim2=possible sums - el_sum_scores = defaultdict(set) # dict of el_idx, sum -> score - el_best_oxid_combo = {} # dict of el_idx, sum -> oxid combo with best score + el_sums: list = [] # matrix: dim1= el_idx, dim2=possible sums + el_sum_scores: defaultdict = defaultdict(set) # dict of el_idx, sum -> score + el_best_oxid_combo: dict = {} # dict of el_idx, sum -> oxid combo with best score for idx, el in enumerate(elements): el_sum_scores[idx] = {} el_best_oxid_combo[idx] = {} el_sums.append([]) if oxi_states_override.get(el): - oxids = oxi_states_override[el] + oxids: list | tuple = oxi_states_override[el] elif all_oxi_states: oxids = Element(el).oxidation_states else: @@ -974,7 +985,7 @@ def _get_oxi_state_guesses(self, all_oxi_states, max_sites, oxi_states_override, el_sums[idx].append(oxid_sum) # Determine how probable is this combo? - score = sum(Composition.oxi_prob.get(Species(el, o), 0) for o in oxid_combo) + score = sum(type(self).oxi_prob.get(Species(el, o), 0) for o in oxid_combo) # type: ignore[union-attr] # If it is the most probable combo for a certain sum, # store the combination @@ -983,7 +994,7 @@ def _get_oxi_state_guesses(self, all_oxi_states, max_sites, oxi_states_override, el_best_oxid_combo[idx][oxid_sum] = oxid_combo # Determine which combination of oxidation states for each element - # is the most probable + # is the most probable all_sols = [] # will contain all solutions all_oxid_combo = [] # will contain the best combination of oxidation states for each site all_scores = [] # will contain a score for each solution @@ -1199,7 +1210,10 @@ def _parse_chomp_and_rank(match, formula: str, m_dict: dict[str, float], m_point yield match -def reduce_formula(sym_amt: dict[str, float] | dict[str, int], iupac_ordering: bool = False) -> tuple[str, float]: +def reduce_formula( + sym_amt: dict[str, float] | dict[str, int], + iupac_ordering: bool = False, +) -> tuple[str, float]: """Helper function to reduce a sym_amt dict to a reduced formula and factor. Args: @@ -1263,30 +1277,30 @@ def __init__(self, *args, **kwargs) -> None: if len(dct) != len(self): raise ValueError("Duplicate potential specified") - def __mul__(self, other: object) -> ChemicalPotential: + def __mul__(self, other: object) -> Self: if isinstance(other, (int, float)): - return ChemicalPotential({key: val * other for key, val in self.items()}) + return type(self)({key: val * other for key, val in self.items()}) return NotImplemented __rmul__ = __mul__ - def __truediv__(self, other: object) -> ChemicalPotential: + def __truediv__(self, other: object) -> Self: if isinstance(other, (int, float)): - return ChemicalPotential({key: val / other for key, val in self.items()}) + return type(self)({key: val / other for key, val in self.items()}) return NotImplemented __div__ = __truediv__ - def __sub__(self, other: object) -> ChemicalPotential: - if isinstance(other, ChemicalPotential): + def __sub__(self, other: object) -> Self: + if isinstance(other, type(self)): els = {*self} | {*other} - return ChemicalPotential({e: self.get(e, 0) - other.get(e, 0) for e in els}) + return type(self)({e: self.get(e, 0) - other.get(e, 0) for e in els}) return NotImplemented - def __add__(self, other: object) -> ChemicalPotential: - if isinstance(other, ChemicalPotential): + def __add__(self, other: object) -> Self: + if isinstance(other, type(self)): els = {*self} | {*other} - return ChemicalPotential({e: self.get(e, 0) + other.get(e, 0) for e in els}) + return type(self)({e: self.get(e, 0) + other.get(e, 0) for e in els}) return NotImplemented def __repr__(self) -> str: diff --git a/pymatgen/core/ion.py b/pymatgen/core/ion.py index daf7efcaf28..a655b4cf8c3 100644 --- a/pymatgen/core/ion.py +++ b/pymatgen/core/ion.py @@ -99,12 +99,8 @@ def from_formula(cls, formula: str) -> Self: # If no brackets, parse trailing +/- for m_chg in re.finditer(r"([+-])([\.\d]*)", formula): - sign = m_chg.group(1) - sgn = float(f"{sign}1") - if m_chg.group(2).strip() != "": - charge += float(m_chg.group(2)) * sgn - else: - charge += sgn + sgn = float(f"{m_chg[1]}1") + charge += float(m_chg[2]) * sgn if m_chg[2].strip() != "" else sgn formula = formula.replace(m_chg.group(), "", 1) return cls(Composition(formula), charge) @@ -259,7 +255,7 @@ def from_dict(cls, dct: dict) -> Self: dct_copy = deepcopy(dct) charge = dct_copy.pop("charge") composition = Composition(dct_copy) - return Ion(composition, charge) + return cls(composition, charge) @property def to_reduced_dict(self) -> dict: @@ -312,7 +308,7 @@ def oxi_state_guesses( # type: ignore[override] oxidation state across all sites in that composition. If the composition is not charge balanced, an empty list is returned. """ - return self._get_oxi_state_guesses(all_oxi_states, max_sites, oxi_states_override, self.charge)[0] + return self._get_oxi_state_guesses(all_oxi_states, max_sites, oxi_states_override, self.charge)[0] # type: ignore[return-value] def to_pretty_string(self) -> str: """Pretty string with proper superscripts.""" diff --git a/pymatgen/core/lattice.py b/pymatgen/core/lattice.py index fb9c4019141..1449a3d5040 100644 --- a/pymatgen/core/lattice.py +++ b/pymatgen/core/lattice.py @@ -321,7 +321,14 @@ def orthorhombic( return cls.from_parameters(a, b, c, 90, 90, 90, pbc=pbc) @classmethod - def monoclinic(cls, a: float, b: float, c: float, beta: float, pbc: PbcLike = (True, True, True)) -> Self: + def monoclinic( + cls, + a: float, + b: float, + c: float, + beta: float, + pbc: PbcLike = (True, True, True), + ) -> Self: """Convenience constructor for a monoclinic lattice. Args: @@ -1040,8 +1047,8 @@ def _calculate_lll(self, delta: float = 0.75) -> tuple[np.ndarray, np.ndarray]: b[:, 0] = a[:, 0] m[0] = np.dot(b[:, 0], b[:, 0]) for i in range(1, 3): - u[i, 0:i] = np.dot(a[:, i].T, b[:, 0:i]) / m[0:i] - b[:, i] = a[:, i] - np.dot(b[:, 0:i], u[i, 0:i].T) + u[i, :i] = np.dot(a[:, i].T, b[:, :i]) / m[:i] + b[:, i] = a[:, i] - np.dot(b[:, :i], u[i, :i].T) m[i] = np.dot(b[:, i], b[:, i]) k = 2 @@ -1078,8 +1085,8 @@ def _calculate_lll(self, delta: float = 0.75) -> tuple[np.ndarray, np.ndarray]: # Update the Gram-Schmidt coefficients for s in range(k - 1, k + 1): - u[s - 1, 0 : (s - 1)] = np.dot(a[:, s - 1].T, b[:, 0 : (s - 1)]) / m[0 : (s - 1)] - b[:, s - 1] = a[:, s - 1] - np.dot(b[:, 0 : (s - 1)], u[s - 1, 0 : (s - 1)].T) + u[s - 1, : (s - 1)] = np.dot(a[:, s - 1].T, b[:, : (s - 1)]) / m[: (s - 1)] + b[:, s - 1] = a[:, s - 1] - np.dot(b[:, : (s - 1)], u[s - 1, : (s - 1)].T) m[s - 1] = np.dot(b[:, s - 1], b[:, s - 1]) if k > 2: @@ -1110,7 +1117,7 @@ def get_frac_coords_from_lll(self, lll_frac_coords: ArrayLike) -> np.ndarray: Doi("10.1107/S010876730302186X"), description="Numerically stable algorithms for the computation of reduced unit cells", ) - def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: + def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Self: """Get the Niggli reduced lattice using the numerically stable algo proposed by R. W. Grosse-Kunstleve, N. K. Sauter, & P. D. Adams, Acta Crystallographica Section A Foundations of Crystallography, 2003, @@ -1215,13 +1222,13 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: alpha = math.acos(E / 2 / b / c) / math.pi * 180 beta = math.acos(N / 2 / a / c) / math.pi * 180 gamma = math.acos(Y / 2 / a / b) / math.pi * 180 - lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma) + lattice = type(self).from_parameters(a, b, c, alpha, beta, gamma) mapped = self.find_mapping(lattice, e, skip_rotation_matrix=True) if mapped is not None: if np.linalg.det(mapped[0].matrix) > 0: return mapped[0] - return Lattice(-mapped[0].matrix) + return type(self)(-mapped[0].matrix) raise ValueError("can't find niggli") @@ -1824,7 +1831,7 @@ def get_points_in_spheres( valid_coords.append(coords[valid_index_bool]) valid_images.append(np.tile(image, [np.sum(valid_index_bool), 1]) - image_offsets[valid_index_bool]) valid_indices.extend([k for k in ind if valid_index_bool[k]]) - if len(valid_coords) < 1: + if not valid_coords: return [[]] * len(center_coords) valid_coords = np.concatenate(valid_coords, axis=0) valid_images = np.concatenate(valid_images, axis=0) diff --git a/pymatgen/core/molecular_orbitals.py b/pymatgen/core/molecular_orbitals.py index 2ad21b0b9dc..86649af3e2c 100644 --- a/pymatgen/core/molecular_orbitals.py +++ b/pymatgen/core/molecular_orbitals.py @@ -5,10 +5,14 @@ from __future__ import annotations from itertools import chain, combinations +from typing import TYPE_CHECKING from pymatgen.core import Element from pymatgen.core.composition import Composition +if TYPE_CHECKING: + from typing import Any + class MolecularOrbitals: """Represents the character of bands in a solid. The input is a chemical @@ -28,7 +32,7 @@ class MolecularOrbitals: # gives {'HOMO':['O','2p',-0.338381], 'LUMO':['Ti','3d',-0.17001], 'metal':False} """ - def __init__(self, formula) -> None: + def __init__(self, formula: str) -> None: """ Args: formula (str): Chemical formula. Must have integer subscripts. Ex: 'SrTiO3'. @@ -52,18 +56,18 @@ def __init__(self, formula) -> None: self.aos = {str(el): [[str(el), k, v] for k, v in Element(el).atomic_orbitals.items()] for el in self.elements} self.band_edges = self.obtain_band_edges() - def max_electronegativity(self): + def max_electronegativity(self) -> float: """ Returns: The maximum pairwise electronegativity difference. """ - maximum = 0 + maximum: float = 0.0 for e1, e2 in combinations(self.elements, 2): if abs(Element(e1).X - Element(e2).X) > maximum: maximum = abs(Element(e1).X - Element(e2).X) return maximum - def aos_as_list(self): + def aos_as_list(self) -> list[tuple[str, str, float]]: """The orbitals energies in eV are represented as [['O', '1s', -18.758245], ['O', '2s', -0.871362], ['O', '2p', -0.338381]] Data is obtained from @@ -73,11 +77,11 @@ def aos_as_list(self): A list of atomic orbitals, sorted from lowest to highest energy. """ return sorted( - chain.from_iterable([self.aos[el] * int(self.composition[el]) for el in self.elements]), + chain.from_iterable([self.aos[el] * int(self.composition[el]) for el in self.elements]), # type: ignore[misc] key=lambda x: x[2], ) - def obtain_band_edges(self): + def obtain_band_edges(self) -> dict[str, Any]: """Fill up the atomic orbitals with available electrons. Returns: diff --git a/pymatgen/core/operations.py b/pymatgen/core/operations.py index 2e2a8aa878c..a89b2e7a8fe 100644 --- a/pymatgen/core/operations.py +++ b/pymatgen/core/operations.py @@ -4,10 +4,9 @@ import re import string -import typing import warnings from math import cos, pi, sin, sqrt -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, cast import numpy as np from monty.json import MSONable @@ -34,7 +33,11 @@ class SymmOp(MSONable): affine_matrix (np.ndarray): A 4x4 array representing the symmetry operation. """ - def __init__(self, affine_transformation_matrix: ArrayLike, tol: float = 0.01) -> None: + def __init__( + self, + affine_transformation_matrix: ArrayLike, + tol: float = 0.01, + ) -> None: """Initialize the SymmOp from a 4x4 affine transformation matrix. In general, this constructor should not be used unless you are transferring rotations. Use the static constructors instead to @@ -55,6 +58,26 @@ def __init__(self, affine_transformation_matrix: ArrayLike, tol: float = 0.01) - self.affine_matrix = affine_transformation_matrix self.tol = tol + def __eq__(self, other: object) -> bool: + if not isinstance(other, SymmOp): + return NotImplemented + return np.allclose(self.affine_matrix, other.affine_matrix, atol=self.tol) + + def __hash__(self) -> int: + return 7 + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.affine_matrix=})" + + def __str__(self) -> str: + return "\n".join(["Rot:", str(self.affine_matrix[:3][:, :3]), "tau", str(self.affine_matrix[:3][:, 3])]) + + def __mul__(self, other) -> Self: + """Get a new SymmOp which is equivalent to apply the "other" SymmOp + followed by this one. + """ + return type(self)(np.dot(self.affine_matrix, other.affine_matrix)) + @classmethod def from_rotation_and_translation( cls, @@ -79,27 +102,12 @@ def from_rotation_and_translation( raise ValueError("Rotation Matrix must be a 3x3 numpy array.") if translation_vec.shape != (3,): raise ValueError("Translation vector must be a rank 1 numpy array with 3 elements.") + affine_matrix = np.eye(4) - affine_matrix[0:3][:, 0:3] = rotation_matrix - affine_matrix[0:3][:, 3] = translation_vec + affine_matrix[:3][:, :3] = rotation_matrix + affine_matrix[:3][:, 3] = translation_vec return cls(affine_matrix, tol) - def __eq__(self, other: object) -> bool: - if not isinstance(other, SymmOp): - return NotImplemented - return np.allclose(self.affine_matrix, other.affine_matrix, atol=self.tol) - - def __hash__(self) -> int: - return 7 - - def __repr__(self) -> str: - affine_matrix = self.affine_matrix - return f"{type(self).__name__}({affine_matrix=})" - - def __str__(self) -> str: - output = ["Rot:", str(self.affine_matrix[0:3][:, 0:3]), "tau", str(self.affine_matrix[0:3][:, 3])] - return "\n".join(output) - def operate(self, point: ArrayLike) -> np.ndarray: """Apply the operation on a point. @@ -109,8 +117,8 @@ def operate(self, point: ArrayLike) -> np.ndarray: Returns: Coordinates of point after operation. """ - affine_point = np.array([*point, 1]) # type: ignore - return np.dot(self.affine_matrix, affine_point)[0:3] + affine_point = np.array([*point, 1]) + return np.dot(self.affine_matrix, affine_point)[:3] def operate_multi(self, points: ArrayLike) -> np.ndarray: """Apply the operation on a list of points. @@ -147,6 +155,7 @@ def transform_tensor(self, tensor: np.ndarray) -> np.ndarray: dim = tensor.shape rank = len(dim) assert all(val == 3 for val in dim) + # Build einstein sum string lc = string.ascii_lowercase indices = lc[:rank], lc[rank : 2 * rank] @@ -156,7 +165,12 @@ def transform_tensor(self, tensor: np.ndarray) -> np.ndarray: return np.einsum(einsum_string, *einsum_args) - def are_symmetrically_related(self, point_a: ArrayLike, point_b: ArrayLike, tol: float = 0.001) -> bool: + def are_symmetrically_related( + self, + point_a: ArrayLike, + point_b: ArrayLike, + tol: float = 0.001, + ) -> bool: """Check if two points are symmetrically related. Args: @@ -218,19 +232,12 @@ def are_symmetrically_related_vectors( @property def rotation_matrix(self) -> np.ndarray: """A 3x3 numpy.array representing the rotation matrix.""" - return self.affine_matrix[0:3][:, 0:3] + return self.affine_matrix[:3][:, :3] @property def translation_vector(self) -> np.ndarray: """A rank 1 numpy.array of dim 3 representing the translation vector.""" - return self.affine_matrix[0:3][:, 3] - - def __mul__(self, other): - """Get a new SymmOp which is equivalent to apply the "other" SymmOp - followed by this one. - """ - new_matrix = np.dot(self.affine_matrix, other.affine_matrix) - return SymmOp(new_matrix) + return self.affine_matrix[:3][:, 3] @property def inverse(self) -> Self: @@ -263,24 +270,26 @@ def from_axis_angle_and_translation( ang = angle if angle_in_radians else angle * pi / 180 cos_a = cos(ang) sin_a = sin(ang) - unit_vec = axis / np.linalg.norm(axis) # type: ignore + unit_vec = axis / np.linalg.norm(axis) rot_mat = np.zeros((3, 3)) - rot_mat[0, 0] = cos_a + unit_vec[0] ** 2 * (1 - cos_a) # type: ignore - rot_mat[0, 1] = unit_vec[0] * unit_vec[1] * (1 - cos_a) - unit_vec[2] * sin_a # type: ignore - rot_mat[0, 2] = unit_vec[0] * unit_vec[2] * (1 - cos_a) + unit_vec[1] * sin_a # type: ignore - rot_mat[1, 0] = unit_vec[0] * unit_vec[1] * (1 - cos_a) + unit_vec[2] * sin_a # type: ignore - rot_mat[1, 1] = cos_a + unit_vec[1] ** 2 * (1 - cos_a) # type: ignore - rot_mat[1, 2] = unit_vec[1] * unit_vec[2] * (1 - cos_a) - unit_vec[0] * sin_a # type: ignore - rot_mat[2, 0] = unit_vec[0] * unit_vec[2] * (1 - cos_a) - unit_vec[1] * sin_a # type: ignore - rot_mat[2, 1] = unit_vec[1] * unit_vec[2] * (1 - cos_a) + unit_vec[0] * sin_a # type: ignore - rot_mat[2, 2] = cos_a + unit_vec[2] ** 2 * (1 - cos_a) # type: ignore + rot_mat[0, 0] = cos_a + unit_vec[0] ** 2 * (1 - cos_a) + rot_mat[0, 1] = unit_vec[0] * unit_vec[1] * (1 - cos_a) - unit_vec[2] * sin_a + rot_mat[0, 2] = unit_vec[0] * unit_vec[2] * (1 - cos_a) + unit_vec[1] * sin_a + rot_mat[1, 0] = unit_vec[0] * unit_vec[1] * (1 - cos_a) + unit_vec[2] * sin_a + rot_mat[1, 1] = cos_a + unit_vec[1] ** 2 * (1 - cos_a) + rot_mat[1, 2] = unit_vec[1] * unit_vec[2] * (1 - cos_a) - unit_vec[0] * sin_a + rot_mat[2, 0] = unit_vec[0] * unit_vec[2] * (1 - cos_a) - unit_vec[1] * sin_a + rot_mat[2, 1] = unit_vec[1] * unit_vec[2] * (1 - cos_a) + unit_vec[0] * sin_a + rot_mat[2, 2] = cos_a + unit_vec[2] ** 2 * (1 - cos_a) return SymmOp.from_rotation_and_translation(rot_mat, vec) - @typing.no_type_check @staticmethod def from_origin_axis_angle( - origin: ArrayLike, axis: ArrayLike, angle: float, angle_in_radians: bool = False + origin: ArrayLike, + axis: ArrayLike, + angle: float, + angle_in_radians: bool = False, ) -> SymmOp: """Generate a SymmOp for a rotation about a given axis through an origin. @@ -296,7 +305,7 @@ def from_origin_axis_angle( Returns: SymmOp. """ - theta = angle * pi / 180 if not angle_in_radians else angle + theta = angle if angle_in_radians else angle * pi / 180 a, b, c = origin ax_u, ax_v, ax_w = axis # Set some intermediate values. @@ -419,11 +428,15 @@ def as_xyz_str(self) -> str: """Get a string of the form 'x, y, z', '-x, -y, z', '-y+1/2, x+1/2, z+1/2', etc. Only works for integer rotation matrices. """ - # test for invalid rotation matrix + # Check for invalid rotation matrix if not np.all(np.isclose(self.rotation_matrix, np.round(self.rotation_matrix))): warnings.warn("Rotation matrix should be integer") - return transformation_to_string(self.rotation_matrix, translation_vec=self.translation_vector, delim=", ") + return transformation_to_string( + self.rotation_matrix, + translation_vec=self.translation_vector, + delim=", ", + ) @classmethod def from_xyz_str(cls, xyz_str: str) -> Self: @@ -440,17 +453,17 @@ def from_xyz_str(cls, xyz_str: str) -> Self: re_rot = re.compile(r"([+-]?)([\d\.]*)/?([\d\.]*)([x-z])") re_trans = re.compile(r"([+-]?)([\d\.]+)/?([\d\.]*)(?![x-z])") for idx, tok in enumerate(tokens): - # build the rotation matrix - for m in re_rot.finditer(tok): - factor = -1.0 if m.group(1) == "-" else 1.0 - if m.group(2) != "": - factor *= float(m.group(2)) / float(m.group(3)) if m.group(3) != "" else float(m.group(2)) - j = ord(m.group(4)) - 120 + # Build the rotation matrix + for match in re_rot.finditer(tok): + factor = -1.0 if match[1] == "-" else 1.0 + if match[2] != "": + factor *= float(match[2]) / float(match[3]) if match[3] != "" else float(match[2]) + j = ord(match[4]) - 120 rot_matrix[idx, j] = factor - # build the translation vector - for m in re_trans.finditer(tok): - factor = -1 if m.group(1) == "-" else 1 - num = float(m.group(2)) / float(m.group(3)) if m.group(3) != "" else float(m.group(2)) + # Build the translation vector + for match in re_trans.finditer(tok): + factor = -1 if match[1] == "-" else 1 + num = float(match[2]) / float(match[3]) if match[3] != "" else float(match[2]) trans[idx] = num * factor return cls.from_rotation_and_translation(rot_matrix, trans) @@ -473,7 +486,12 @@ class MagSymmOp(SymmOp): moment. """ - def __init__(self, affine_transformation_matrix: ArrayLike, time_reversal: int, tol: float = 0.01) -> None: + def __init__( + self, + affine_transformation_matrix: ArrayLike, + time_reversal: Literal[-1, 1], + tol: float = 0.01, + ) -> None: """Initialize the MagSymmOp from a 4x4 affine transformation matrix and time reversal operator. In general, this constructor should not be used unless you are transferring rotations. Use the static constructors instead to generate a SymmOp from proper rotations @@ -485,14 +503,14 @@ def __init__(self, affine_transformation_matrix: ArrayLike, time_reversal: int, time_reversal (int): 1 or -1 tol (float): Tolerance for determining if matrices are equal. """ - SymmOp.__init__(self, affine_transformation_matrix, tol=tol) + super().__init__(affine_transformation_matrix, tol=tol) if time_reversal in {-1, 1}: self.time_reversal = time_reversal else: raise RuntimeError(f"Invalid {time_reversal=}, must be 1 or -1") def __eq__(self, other: object) -> bool: - if not isinstance(other, SymmOp): + if not isinstance(other, type(self)): return NotImplemented return np.allclose(self.affine_matrix, other.affine_matrix, atol=self.tol) and ( self.time_reversal == other.time_reversal @@ -502,18 +520,19 @@ def __str__(self) -> str: return self.as_xyzt_str() def __repr__(self) -> str: - output = [ - "Rot:", - str(self.affine_matrix[0:3][:, 0:3]), - "tau", - str(self.affine_matrix[0:3][:, 3]), - "Time reversal:", - str(self.time_reversal), - ] - return "\n".join(output) + return "\n".join( + [ + "Rot:", + str(self.affine_matrix[:3][:, :3]), + "tau", + str(self.affine_matrix[:3][:, 3]), + "Time reversal:", + str(self.time_reversal), + ] + ) def __hash__(self) -> int: - # useful for obtaining a set of unique MagSymmOps + """Useful for obtaining a set of unique MagSymmOps.""" hashable_value = (*tuple(self.affine_matrix.flatten()), self.time_reversal) return hash(hashable_value) @@ -521,7 +540,7 @@ def __hash__(self) -> int: Doi("10.1051/epjconf/20122200010"), description="Symmetry and magnetic structures", ) - def operate_magmom(self, magmom): + def operate_magmom(self, magmom: Magmom) -> Magmom: """Apply time reversal operator on the magnetic moment. Note that magnetic moments transform as axial vectors, not polar vectors. @@ -535,17 +554,18 @@ class or as list or np array-like Returns: Magnetic moment after operator applied as Magmom class """ - magmom = Magmom(magmom) # type casting to handle lists as input + # Type casting to handle lists as input + magmom = Magmom(magmom) transformed_moment = ( self.apply_rotation_only(magmom.global_moment) * np.linalg.det(self.rotation_matrix) * self.time_reversal ) - # retains input spin axis if different from default + # Retain input spin axis if different from default return Magmom.from_global_moment_and_saxis(transformed_moment, magmom.saxis) @classmethod - def from_symmop(cls, symmop: SymmOp, time_reversal) -> Self: + def from_symmop(cls, symmop: SymmOp, time_reversal: Literal[-1, 1]) -> Self: """Initialize a MagSymmOp from a SymmOp and time reversal operator. Args: @@ -561,7 +581,7 @@ def from_symmop(cls, symmop: SymmOp, time_reversal) -> Self: def from_rotation_and_translation_and_time_reversal( rotation_matrix: ArrayLike = ((1, 0, 0), (0, 1, 0), (0, 0, 1)), translation_vec: ArrayLike = (0, 0, 0), - time_reversal: int = 1, + time_reversal: Literal[-1, 1] = 1, tol: float = 0.1, ) -> MagSymmOp: """Create a symmetry operation from a rotation matrix, translation @@ -596,13 +616,17 @@ def from_xyzt_str(cls, xyzt_str: str) -> Self: time_reversal = int(xyzt_str.rsplit(",", 1)[1]) except Exception: raise RuntimeError("Time reversal operator could not be parsed.") - return cls.from_symmop(symm_op, time_reversal) + + if time_reversal in {-1, 1}: + return cls.from_symmop(symm_op, cast(Literal[-1, 1], time_reversal)) + + raise RuntimeError("Time reversal should be -1 or 1.") def as_xyzt_str(self) -> str: """Get a string of the form 'x, y, z, +1', '-x, -y, z, -1', '-y+1/2, x+1/2, z+1/2, +1', etc. Only works for integer rotation matrices. """ - xyzt_string = SymmOp.as_xyz_str(self) + xyzt_string = super().as_xyz_str() return f"{xyzt_string}, {self.time_reversal:+}" def as_dict(self) -> dict[str, Any]: diff --git a/pymatgen/core/periodic_table.py b/pymatgen/core/periodic_table.py index 9fbc9ecfd6c..14694066ce9 100644 --- a/pymatgen/core/periodic_table.py +++ b/pymatgen/core/periodic_table.py @@ -1021,9 +1021,7 @@ def __lt__(self, other: object) -> bool: other_oxi = 0 if (isinstance(other, Element) or other.oxi_state is None) else other.oxi_state return self.oxi_state < other_oxi if self.spin is not None: - if other.spin is not None: - return self.spin < other.spin - return False + return self.spin < other.spin if other.spin is not None else False return False @@ -1043,8 +1041,8 @@ def __str__(self) -> str: output += f",{spin=}" return output - def __deepcopy__(self, memo) -> Species: - return Species(self.symbol, self.oxi_state, spin=self._spin) + def __deepcopy__(self, memo) -> Self: + return type(self)(self.symbol, self.oxi_state, spin=self._spin) @property def element(self) -> Element: @@ -1343,7 +1341,7 @@ def __str__(self) -> str: return output def __deepcopy__(self, memo) -> Self: - return DummySpecies(self.symbol, self._oxi_state) + return type(self)(self.symbol, self._oxi_state) @property def Z(self) -> int: diff --git a/pymatgen/core/sites.py b/pymatgen/core/sites.py index 26e83843e44..863ec8833b5 100644 --- a/pymatgen/core/sites.py +++ b/pymatgen/core/sites.py @@ -4,7 +4,7 @@ import collections import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np from monty.json import MontyDecoder, MontyEncoder, MSONable @@ -15,10 +15,12 @@ from pymatgen.util.coord import pbc_diff if TYPE_CHECKING: + from typing import Any + from numpy.typing import ArrayLike from typing_extensions import Self - from pymatgen.util.typing import CompositionLike, SpeciesLike + from pymatgen.util.typing import CompositionLike, SpeciesLike, Vector3D class Site(collections.abc.Hashable, MSONable): @@ -61,41 +63,94 @@ def __init__( if not skip_checks: if not isinstance(species, Composition): try: - species = Composition({get_el_sp(species): 1}) # type: ignore + species = Composition({get_el_sp(species): 1}) # type: ignore[arg-type] except TypeError: species = Composition(species) total_occu = species.num_atoms if total_occu > 1 + Composition.amount_tolerance: raise ValueError("Species occupancies sum to more than 1!") coords = np.array(coords) - self._species: Composition = species # type: ignore - self.coords: np.ndarray = coords # type: ignore + + self._species = species + self.coords: np.ndarray = coords self.properties: dict = properties or {} self._label = label - def __getattr__(self, attr): - # overriding getattr doesn't play nicely with pickle, so we can't use self._properties + def __getattr__(self, attr: str) -> Any: + # Override getattr doesn't play nicely with pickle, + # so we can't use self._properties props = self.__getattribute__("properties") if attr in props: return props[attr] raise AttributeError(f"{attr=} not found on {type(self).__name__}") + def __getitem__(self, el: Element) -> float: # type: ignore[override] + """Get the occupancy for element.""" + return self.species[el] + + def __eq__(self, other: object) -> bool: + """Site is equal to another site if the species and occupancies are the + same, and the coordinates are the same to some tolerance. numpy + function `allclose` is used to determine if coordinates are close. + """ + if not isinstance(other, type(self)): + return NotImplemented + + return ( + self.species == other.species + and np.allclose(self.coords, other.coords, atol=type(self).position_atol) + and self.properties == other.properties + ) + + def __hash__(self) -> int: # type: ignore[override] + """Minimally effective hash function that just distinguishes between Sites + with different elements. + """ + return sum(el.Z for el in self.species) + + def __contains__(self, el: Element) -> bool: + return el in self.species + + def __repr__(self) -> str: + name = self.species_string + + if self.label != name: + name = f"{self.label} ({name})" + + return f"Site: {name} ({self.coords[0]:.4f}, {self.coords[1]:.4f}, {self.coords[2]:.4f})" + + def __lt__(self, other: Site) -> bool: + """Set a default sort order for atomic species by electronegativity. Very + useful for getting correct formulas. For example, FeO4PLi is + automatically sorted in LiFePO4. + """ + if self.species.average_electroneg < other.species.average_electroneg: + return True + if self.species.average_electroneg > other.species.average_electroneg: + return False + return self.species_string < other.species_string + + def __str__(self) -> str: + return f"{self.coords} {self.species_string}" + @property def species(self) -> Composition: """The species on the site as a composition, e.g. Fe0.5Mn0.5.""" - return self._species + return cast(Composition, self._species) @species.setter def species(self, species: SpeciesLike | CompositionLike) -> None: if not isinstance(species, Composition): try: - species = Composition({get_el_sp(species): 1}) # type: ignore + species = Composition({get_el_sp(species): 1}) # type: ignore[arg-type] except TypeError: species = Composition(species) + total_occu = species.num_atoms if total_occu > 1 + Composition.amount_tolerance: raise ValueError("Species occupancies sum to more than 1!") - self._species = species + + self._species = cast(Composition, species) @property def label(self) -> str: @@ -133,7 +188,7 @@ def z(self) -> float: def z(self, z: float) -> None: self.coords[2] = z - def distance(self, other) -> float: + def distance(self, other: Site) -> float: """Get distance between two sites. Args: @@ -144,7 +199,7 @@ def distance(self, other) -> float: """ return float(np.linalg.norm(other.coords - self.coords)) - def distance_from_point(self, pt) -> float: + def distance_from_point(self, pt: Vector3D) -> float: """Get distance between the site and a point in space. Args: @@ -185,55 +240,6 @@ def is_ordered(self) -> bool: total_occu = self.species.num_atoms return total_occu == len(self.species) == 1 - def __getitem__(self, el): - """Get the occupancy for element.""" - return self.species[el] - - def __eq__(self, other: object) -> bool: - """Site is equal to another site if the species and occupancies are the - same, and the coordinates are the same to some tolerance. numpy - function `allclose` is used to determine if coordinates are close. - """ - if not isinstance(other, type(self)): - return NotImplemented - - return ( - self.species == other.species - and np.allclose(self.coords, other.coords, atol=Site.position_atol) - and self.properties == other.properties - ) - - def __hash__(self) -> int: - """Minimally effective hash function that just distinguishes between Sites - with different elements. - """ - return sum(el.Z for el in self.species) - - def __contains__(self, el) -> bool: - return el in self.species - - def __repr__(self) -> str: - name = self.species_string - - if self.label != name: - name = f"{self.label} ({name})" - - return f"Site: {name} ({self.coords[0]:.4f}, {self.coords[1]:.4f}, {self.coords[2]:.4f})" - - def __lt__(self, other): - """Set a default sort order for atomic species by electronegativity. Very - useful for getting correct formulas. For example, FeO4PLi is - automatically sorted in LiFePO4. - """ - if self.species.average_electroneg < other.species.average_electroneg: - return True - if self.species.average_electroneg > other.species.average_electroneg: - return False - return self.species_string < other.species_string - - def __str__(self) -> str: - return f"{self.coords} {self.species_string}" - def as_dict(self) -> dict: """JSON-serializable dict representation for Site.""" species = [] @@ -264,11 +270,11 @@ def from_dict(cls, dct: dict) -> Self: atoms_n_occu = {} for sp_occu in dct["species"]: if "oxidation_state" in sp_occu and Element.is_valid_symbol(sp_occu["element"]): - sp = Species.from_dict(sp_occu) + sp: Species | DummySpecies | Element = Species.from_dict(sp_occu) elif "oxidation_state" in sp_occu: sp = DummySpecies.from_dict(sp_occu) else: - sp = Element(sp_occu["element"]) # type: ignore + sp = Element(sp_occu["element"]) atoms_n_occu[sp] = sp_occu["occu"] props = dct.get("properties") if props is not None: @@ -323,13 +329,13 @@ def __init__( frac_coords = lattice.get_fractional_coords(coords) if coords_are_cartesian else coords if to_unit_cell: - frac_coords = np.array([np.mod(f, 1) if p else f for p, f in zip(lattice.pbc, frac_coords)]) # type: ignore + frac_coords = np.array([np.mod(f, 1) if p else f for p, f in zip(lattice.pbc, frac_coords)]) if not skip_checks: frac_coords = np.array(frac_coords) if not isinstance(species, Composition): try: - species = Composition({get_el_sp(species): 1}) # type: ignore + species = Composition({get_el_sp(species): 1}) # type: ignore[arg-type] except TypeError: species = Composition(species) @@ -339,7 +345,7 @@ def __init__( self._lattice: Lattice = lattice self._frac_coords: np.ndarray = np.asarray(frac_coords) - self._species: Composition = species # type: ignore + self._species: Composition = cast(Composition, species) self._coords: np.ndarray | None = None self.properties: dict = properties or {} self._label = label @@ -350,6 +356,28 @@ def __hash__(self) -> int: """ return sum(el.Z for el in self.species) + def __eq__(self, other: object) -> bool: + if not isinstance(other, type(self)): + return NotImplemented + + return ( + self.species == other.species + and self.lattice == other.lattice + and np.allclose(self.coords, other.coords, atol=Site.position_atol) + and self.properties == other.properties + ) + + def __repr__(self) -> str: + name = self.species_string + + if self.label != name: + name = f"{self.label} ({name})" + + x, y, z = self.coords + x_frac, y_frac, z_frac = map(float, self.frac_coords) + cls_name = type(self).__name__ + return f"{cls_name}: {name} ({x:.4}, {y:.4}, {z:.4}) [{x_frac:.4}, {y_frac:.4}, {z_frac:.4}]" + @property def lattice(self) -> Lattice: """Lattice associated with PeriodicSite.""" @@ -369,7 +397,7 @@ def coords(self) -> np.ndarray: return self._coords @coords.setter - def coords(self, coords) -> None: + def coords(self, coords: np.ndarray) -> None: """Set Cartesian coordinates.""" self._coords = np.array(coords) self._frac_coords = self._lattice.get_fractional_coords(self._coords) @@ -380,7 +408,7 @@ def frac_coords(self) -> np.ndarray: return self._frac_coords @frac_coords.setter - def frac_coords(self, frac_coords) -> None: + def frac_coords(self, frac_coords: np.ndarray) -> None: """Set fractional coordinates.""" self._frac_coords = np.array(frac_coords) self._coords = self._lattice.get_cartesian_coords(self._frac_coords) @@ -445,16 +473,21 @@ def z(self, z: float) -> None: self.coords[2] = z self._frac_coords = self._lattice.get_fractional_coords(self.coords) - def to_unit_cell(self, in_place=False) -> PeriodicSite | None: + def to_unit_cell(self, in_place: bool = False) -> PeriodicSite | None: """Move frac coords to within the unit cell.""" frac_coords = [np.mod(f, 1) if p else f for p, f in zip(self.lattice.pbc, self.frac_coords)] if in_place: self.frac_coords = np.array(frac_coords) return None - return PeriodicSite(self.species, frac_coords, self.lattice, properties=self.properties, label=self.label) + return type(self)(self.species, frac_coords, self.lattice, properties=self.properties, label=self.label) - def is_periodic_image(self, other: PeriodicSite, tolerance: float = 1e-8, check_lattice: bool = True) -> bool: - """Get True if sites are periodic images of each other. + def is_periodic_image( + self, + other: PeriodicSite, + tolerance: float = 1e-8, + check_lattice: bool = True, + ) -> bool: + """Check if sites are periodic images of each other. Args: other (PeriodicSite): Other site @@ -473,19 +506,10 @@ def is_periodic_image(self, other: PeriodicSite, tolerance: float = 1e-8, check_ frac_diff = pbc_diff(self.frac_coords, other.frac_coords, self.lattice.pbc) return np.allclose(frac_diff, [0, 0, 0], atol=tolerance) - def __eq__(self, other: object) -> bool: - if not isinstance(other, Site): - return NotImplemented - - return ( - self.species == other.species - and self.lattice == other.lattice - and np.allclose(self.coords, other.coords, atol=Site.position_atol) - and self.properties == other.properties - ) - def distance_and_image_from_frac_coords( - self, fcoords: ArrayLike, jimage: ArrayLike | None = None + self, + fcoords: ArrayLike, + jimage: ArrayLike | None = None, ) -> tuple[float, np.ndarray]: """Get distance between site and a fractional coordinate assuming periodic boundary conditions. If the index jimage of two sites atom j @@ -508,7 +532,11 @@ def distance_and_image_from_frac_coords( """ return self.lattice.get_distance_and_image(self.frac_coords, fcoords, jimage=jimage) - def distance_and_image(self, other: PeriodicSite, jimage: ArrayLike | None = None) -> tuple[float, np.ndarray]: + def distance_and_image( + self, + other: PeriodicSite, + jimage: ArrayLike | None = None, + ) -> tuple[float, np.ndarray]: """Get distance and instance between two sites assuming periodic boundary conditions. If the index jimage of two sites atom j is not specified it selects the j image nearest to the i atom and returns the distance and @@ -529,7 +557,11 @@ def distance_and_image(self, other: PeriodicSite, jimage: ArrayLike | None = Non """ return self.distance_and_image_from_frac_coords(other.frac_coords, jimage) - def distance(self, other: PeriodicSite, jimage: ArrayLike | None = None): + def distance( + self, + other: PeriodicSite, + jimage: ArrayLike | None = None, + ) -> float: """Get distance between two sites assuming periodic boundary conditions. Args: @@ -544,17 +576,6 @@ def distance(self, other: PeriodicSite, jimage: ArrayLike | None = None): """ return self.distance_and_image(other, jimage)[0] - def __repr__(self) -> str: - name = self.species_string - - if self.label != name: - name = f"{self.label} ({name})" - - x, y, z = self.coords - x_frac, y_frac, z_frac = map(float, self.frac_coords) - cls_name = type(self).__name__ - return f"{cls_name}: {name} ({x:.4}, {y:.4}, {z:.4}) [{x_frac:.4}, {y_frac:.4}, {z_frac:.4}]" - def as_dict(self, verbosity: int = 0) -> dict: """JSON-serializable dict representation of PeriodicSite. @@ -572,7 +593,7 @@ def as_dict(self, verbosity: int = 0) -> dict: dct = { "species": species, - "abc": [float(c) for c in self._frac_coords], # type: ignore + "abc": [float(c) for c in self._frac_coords], "lattice": self._lattice.as_dict(verbosity=verbosity), "@module": type(self).__module__, "@class": type(self).__name__, @@ -587,7 +608,7 @@ def as_dict(self, verbosity: int = 0) -> dict: return dct @classmethod - def from_dict(cls, dct, lattice=None) -> Self: + def from_dict(cls, dct: dict, lattice: Lattice | None = None) -> Self: """Create PeriodicSite from dict representation. Args: @@ -602,12 +623,13 @@ def from_dict(cls, dct, lattice=None) -> Self: species = {} for sp_occu in dct["species"]: if "oxidation_state" in sp_occu and Element.is_valid_symbol(sp_occu["element"]): - sp = Species.from_dict(sp_occu) + sp: Species | DummySpecies | Element = Species.from_dict(sp_occu) elif "oxidation_state" in sp_occu: sp = DummySpecies.from_dict(sp_occu) else: - sp = Element(sp_occu["element"]) # type: ignore + sp = Element(sp_occu["element"]) species[sp] = sp_occu["occu"] + props = dct.get("properties") if props is not None: for key in props: diff --git a/pymatgen/core/spectrum.py b/pymatgen/core/spectrum.py index 0fce3e757b1..e581d4d3ac3 100644 --- a/pymatgen/core/spectrum.py +++ b/pymatgen/core/spectrum.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Literal +from typing import TYPE_CHECKING import numpy as np from monty.json import MSONable @@ -14,12 +14,15 @@ from pymatgen.util.coord import get_linear_interpolated_value if TYPE_CHECKING: - from numpy.typing import ArrayLike + from typing import Callable, Literal + + from numpy.typing import NDArray from typing_extensions import Self -def lorentzian(x, x_0: float = 0, sigma: float = 1.0): - """ +def lorentzian(x: NDArray, x_0: float = 0, sigma: float = 1.0) -> NDArray: + """The Lorentzian smearing function. + Args: x: x values x_0: Center @@ -32,7 +35,7 @@ def lorentzian(x, x_0: float = 0, sigma: float = 1.0): class Spectrum(MSONable): - """Base class for any type of xas, essentially just x, y values. Examples + """Base class for any type of XAS, essentially just x, y values. Examples include XRD patterns, XANES, EXAFS, NMR, DOS, etc. Implements basic tools like application of smearing, normalization, addition @@ -46,7 +49,7 @@ class Spectrum(MSONable): XLABEL = "x" YLABEL = "y" - def __init__(self, x: ArrayLike, y: ArrayLike, *args, **kwargs) -> None: + def __init__(self, x: NDArray, y: NDArray, *args, **kwargs) -> None: """ Args: x (ndarray): A ndarray of N values. @@ -66,7 +69,7 @@ def __init__(self, x: ArrayLike, y: ArrayLike, *args, **kwargs) -> None: self._args = args self._kwargs = kwargs - def __getattr__(self, name): + def __getattr__(self, name: str) -> NDArray: if name == self.XLABEL.lower(): return self.x if name == self.YLABEL.lower(): @@ -76,72 +79,7 @@ def __getattr__(self, name): def __len__(self) -> int: return self.ydim[0] - def normalize(self, mode: Literal["max", "sum"] = "max", value: float = 1.0) -> None: - """Normalize the spectrum with respect to the sum of intensity. - - Args: - mode ("max" | "sum"): Normalization mode. "max" sets the max y value to value, - e.g. in XRD patterns. "sum" sets the sum of y to a value, i.e., like a - probability density. - value (float): Value to normalize to. Defaults to 1. - """ - if mode.lower() == "sum": - factor = np.sum(self.y, axis=0) - elif mode.lower() == "max": - factor = np.max(self.y, axis=0) - else: - raise ValueError(f"Unsupported normalization {mode=}!") - - self.y /= factor / value - - def smear(self, sigma: float = 0.0, func: str | Callable = "gaussian") -> None: - """Apply Gaussian/Lorentzian smearing to spectrum y value. - - Args: - sigma: Std dev for Gaussian smear function - func: "gaussian" or "lorentzian" or a callable. If this is a callable, the sigma value is ignored. The - callable should only take a single argument (a numpy array) and return a set of weights. - """ - points = np.linspace(np.min(self.x) - np.mean(self.x), np.max(self.x) - np.mean(self.x), len(self.x)) - if callable(func): - weights = func(points) - elif func.lower() == "gaussian": - weights = stats.norm.pdf(points, scale=sigma) - elif func.lower() == "lorentzian": - weights = lorentzian(points, sigma=sigma) - else: - raise ValueError(f"Invalid {func=}") - weights /= np.sum(weights) - if len(self.ydim) == 1: - total = np.sum(self.y) - self.y = convolve1d(self.y, weights) - self.y *= total / np.sum(self.y) # renormalize to maintain the same integrated sum as before. - else: - total = np.sum(self.y, axis=0) - self.y = np.array([convolve1d(self.y[:, k], weights) for k in range(self.ydim[1])]).T - self.y *= total / np.sum(self.y, axis=0) # renormalize to maintain the same integrated sum as before. - - def get_interpolated_value(self, x: float) -> float | list[float]: - """Get an interpolated y value for a particular x value. - - Args: - x: x value to return the y value for - - Returns: - Value of y at x - """ - if len(self.ydim) == 1: - return get_linear_interpolated_value(self.x, self.y, x) - return [get_linear_interpolated_value(self.x, self.y[:, k], x) for k in range(self.ydim[1])] - - def copy(self) -> Self: - """ - Returns: - Copy of Spectrum object. - """ - return type(self)(self.x, self.y, *self._args, **self._kwargs) - - def __add__(self, other): + def __add__(self, other: Spectrum) -> Self: """Add two Spectrum object together. Checks that x scales are the same. Otherwise, a ValueError is thrown. @@ -153,9 +91,10 @@ def __add__(self, other): """ if not all(np.equal(self.x, other.x)): raise ValueError("X axis values are not compatible!") + return type(self)(self.x, self.y + other.y, *self._args, **self._kwargs) - def __sub__(self, other): + def __sub__(self, other: Spectrum) -> Self: """Subtract one Spectrum object from another. Checks that x scales are the same. Otherwise, a ValueError is thrown. @@ -168,9 +107,10 @@ def __sub__(self, other): """ if not all(np.equal(self.x, other.x)): raise ValueError("X axis values are not compatible!") + return type(self)(self.x, self.y - other.y, *self._args, **self._kwargs) - def __mul__(self, other): + def __mul__(self, other: Spectrum) -> Self: """Scale the Spectrum's y values. Args: @@ -183,7 +123,7 @@ def __mul__(self, other): __rmul__ = __mul__ - def __truediv__(self, other): + def __truediv__(self, other: Spectrum) -> Self: """True division of y. Args: @@ -194,7 +134,7 @@ def __truediv__(self, other): """ return type(self)(self.x, self.y.__truediv__(other), *self._args, **self._kwargs) - def __floordiv__(self, other): + def __floordiv__(self, other: Spectrum) -> Self: """True division of y. Args: @@ -208,11 +148,84 @@ def __floordiv__(self, other): __div__ = __truediv__ def __str__(self) -> str: - """Get a string containing values and labels of spectrum object for + """String containing values and labels of spectrum object for plotting. """ return f"{type(self).__name__}\n{self.XLABEL}: {self.x}\n{self.YLABEL}: {self.y}" def __repr__(self) -> str: - """Get a printable representation of the class.""" + """A printable representation of the class.""" return str(self) + + def normalize( + self, + mode: Literal["max", "sum"] = "max", + value: float = 1.0, + ) -> None: + """Normalize the spectrum with respect to the sum of intensity. + + Args: + mode ("max" | "sum"): Normalization mode. "max" sets the max y value to value, + e.g. in XRD patterns. "sum" sets the sum of y to a value, i.e., like a + probability density. + value (float): Value to normalize to. Defaults to 1. + """ + if mode.lower() == "sum": + factor = np.sum(self.y, axis=0) + elif mode.lower() == "max": + factor = np.max(self.y, axis=0) + else: + raise ValueError(f"Unsupported normalization {mode=}!") + + self.y /= factor / value + + def smear( + self, + sigma: float = 0.0, + func: Literal["gaussian", "lorentzian"] | Callable = "gaussian", + ) -> None: + """Apply Gaussian/Lorentzian smearing to spectrum y value. + + Args: + sigma: Std dev for Gaussian smear function + func: "gaussian" or "lorentzian" or a callable. If this is a callable, the sigma value is ignored. The + callable should only take a single argument (a numpy array) and return a set of weights. + """ + points = np.linspace(np.min(self.x) - np.mean(self.x), np.max(self.x) - np.mean(self.x), len(self.x)) + if callable(func): + weights = func(points) + elif func.lower() == "gaussian": + weights = stats.norm.pdf(points, scale=sigma) + elif func.lower() == "lorentzian": + weights = lorentzian(points, sigma=sigma) + else: + raise ValueError(f"Invalid {func=}") + weights /= np.sum(weights) + if len(self.ydim) == 1: + total = np.sum(self.y) + self.y = convolve1d(self.y, weights) + self.y *= total / np.sum(self.y) # renormalize to maintain the same integrated sum as before. + else: + total = np.sum(self.y, axis=0) + self.y = np.array([convolve1d(self.y[:, k], weights) for k in range(self.ydim[1])]).T + self.y *= total / np.sum(self.y, axis=0) # renormalize to maintain the same integrated sum as before. + + def get_interpolated_value(self, x: float) -> float | list[float]: + """Get an interpolated y value for a particular x value. + + Args: + x: x value to return the y value for + + Returns: + Value of y at x + """ + if len(self.ydim) == 1: + return get_linear_interpolated_value(self.x, self.y, x) + return [get_linear_interpolated_value(self.x, self.y[:, k], x) for k in range(self.ydim[1])] + + def copy(self) -> Self: + """ + Returns: + Copy of Spectrum object. + """ + return type(self)(self.x, self.y, *self._args, **self._kwargs) diff --git a/pymatgen/core/structure.py b/pymatgen/core/structure.py index 125c2020f23..61b84690166 100644 --- a/pymatgen/core/structure.py +++ b/pymatgen/core/structure.py @@ -102,7 +102,7 @@ def __len__(self) -> Literal[3]: """Make neighbor Tuple-like to retain backwards compatibility.""" return 3 - def __getitem__(self, idx: int): + def __getitem__(self, idx: int) -> Self | float: # type: ignore[override] """Make neighbor Tuple-like to retain backwards compatibility.""" return (self, self.nn_distance, self.index)[idx] @@ -165,19 +165,19 @@ def __init__( self.image = image self._label = label - @property # type: ignore - def coords(self) -> np.ndarray: # type: ignore - """Cartesian coords.""" - return self._lattice.get_cartesian_coords(self._frac_coords) - def __len__(self) -> int: """Make neighbor Tuple-like to retain backwards compatibility.""" return 4 - def __getitem__(self, idx: int | slice): + def __getitem__(self, idx: int | slice): # type: ignore[override] """Make neighbor Tuple-like to retain backwards compatibility.""" return (self, self.nn_distance, self.index, self.image)[idx] + @property # type: ignore + def coords(self) -> np.ndarray: # type: ignore + """Cartesian coords.""" + return self._lattice.get_cartesian_coords(self._frac_coords) + def as_dict(self) -> dict: # type: ignore[override] """Note that method calls the super of Site, which is MSONable itself.""" return super(Site, self).as_dict() @@ -206,6 +206,23 @@ class SiteCollection(collections.abc.Sequence, ABC): DISTANCE_TOLERANCE = 0.5 _properties: dict + def __contains__(self, site: object) -> bool: + return site in self.sites + + def __iter__(self) -> Iterator[Site]: + return iter(self.sites) + + # TODO return type needs fixing (can be list[Site] but raises lots of mypy errors) + def __getitem__(self, ind: int | slice) -> Site: + return self.sites[ind] # type: ignore[return-value] + + def __len__(self) -> int: + return len(self.sites) + + def __hash__(self) -> int: + # for now, just use the composition hash code. + return hash(self.composition) + @property def sites(self) -> list[Site]: """An iterator for the sites in the Structure.""" @@ -362,23 +379,6 @@ def relabel_sites(self, ignore_uniq: bool = False) -> Self: return self - def __contains__(self, site: object) -> bool: - return site in self.sites - - def __iter__(self) -> Iterator[Site]: - return iter(self.sites) - - # TODO return type needs fixing (can be list[Site] but raises lots of mypy errors) - def __getitem__(self, ind: int | slice) -> Site: - return self.sites[ind] # type: ignore[return-value] - - def __len__(self) -> int: - return len(self.sites) - - def __hash__(self) -> int: - # for now, just use the composition hash code. - return hash(self.composition) - @property def num_sites(self) -> int: """Number of sites.""" @@ -1046,6 +1046,124 @@ def __init__( self._charge = charge self._properties = properties or {} + def __eq__(self, other: object) -> bool: + needed_attrs = ("lattice", "sites", "properties") + + if not all(hasattr(other, attr) for attr in needed_attrs): + # return NotImplemented as in https://docs.python.org/3/library/functools.html#functools.total_ordering + return NotImplemented + + other = cast(Structure, other) # make mypy happy + + if other is self: + return True + if len(self) != len(other): + return False + if self.lattice != other.lattice: + return False + if self.properties != other.properties: + return False + return all(site in other for site in self) + + def __hash__(self) -> int: + # For now, just use the composition hash code. + return hash(self.composition) + + def __mul__(self, scaling_matrix: int | Sequence[int] | Sequence[Sequence[int]]) -> Structure: + """Make a supercell. Allowing to have sites outside the unit cell. + + Args: + scaling_matrix: A scaling matrix for transforming the lattice + vectors. Has to be all integers. Several options are possible: + + a. A full 3x3 scaling matrix defining the linear combination + of the old lattice vectors. e.g. [[2,1,0],[0,3,0],[0,0, + 1]] generates a new structure with lattice vectors a' = + 2a + b, b' = 3b, c' = c where a, b, and c are the lattice + vectors of the original structure. + b. A sequence of three scaling factors. e.g. [2, 1, 1] + specifies that the supercell should have dimensions 2a x b x + c. + c. A number, which simply scales all lattice vectors by the + same factor. + + Returns: + Supercell structure. Note that a Structure is always returned, + even if the input structure is a subclass of Structure. This is + to avoid different arguments signatures from causing problems. If + you prefer a subclass to return its own type, you need to override + this method in the subclass. + """ + scale_matrix = np.array(scaling_matrix, int) + if scale_matrix.shape != (3, 3): + scale_matrix = scale_matrix * np.eye(3) + new_lattice = Lattice(np.dot(scale_matrix, self.lattice.matrix)) + + frac_lattice = lattice_points_in_supercell(scale_matrix) + cart_lattice = new_lattice.get_cartesian_coords(frac_lattice) + + new_sites = [] + for site in self: + for vec in cart_lattice: + periodic_site = PeriodicSite( + site.species, + site.coords + vec, + new_lattice, + properties=site.properties, + coords_are_cartesian=True, + to_unit_cell=False, + skip_checks=True, + label=site.label, + ) + new_sites.append(periodic_site) + + new_charge = self._charge * np.linalg.det(scale_matrix) if self._charge else None + return Structure.from_sites(new_sites, charge=new_charge, to_unit_cell=True) + + def __rmul__(self, scaling_matrix): + """Similar to __mul__ to preserve commutativeness.""" + return self * scaling_matrix + + def __repr__(self) -> str: + outs = ["Structure Summary", repr(self.lattice)] + if self._charge: + outs.append(f"Overall Charge: {self._charge:+}") + for site in self: + outs.append(repr(site)) + return "\n".join(outs) + + def __str__(self) -> str: + def to_str(x) -> str: + return f"{x:>10.6f}" + + outs = [ + f"Full Formula ({self.composition.formula})", + f"Reduced Formula: {self.composition.reduced_formula}", + f"abc : {' '.join(to_str(i) for i in self.lattice.abc)}", + f"angles: {' '.join(to_str(i) for i in self.lattice.angles)}", + f"pbc : {' '.join(str(p).rjust(10) for p in self.lattice.pbc)}", + ] + + if self._charge: + outs.append(f"Overall Charge: {self._charge:+}") + outs.append(f"Sites ({len(self)})") + data = [] + props = self.site_properties + keys = sorted(props) + for idx, site in enumerate(self): + row = [str(idx), site.species_string] + row.extend([to_str(j) for j in site.frac_coords]) + for key in keys: + row.append(props[key][idx]) + data.append(row) + outs.append( + tabulate( + data, + headers=["#", "SP", "a", "b", "c", *keys], + ) + ) + return "\n".join(outs) + @classmethod def from_sites( cls, @@ -1406,84 +1524,6 @@ def matches(self, other: IStructure | Structure, anonymous: bool = False, **kwar return matcher.fit_anonymous(self, other) return matcher.fit(self, other) - def __eq__(self, other: object) -> bool: - needed_attrs = ("lattice", "sites", "properties") - - if not all(hasattr(other, attr) for attr in needed_attrs): - # return NotImplemented as in https://docs.python.org/3/library/functools.html#functools.total_ordering - return NotImplemented - - other = cast(Structure, other) # make mypy happy - - if other is self: - return True - if len(self) != len(other): - return False - if self.lattice != other.lattice: - return False - if self.properties != other.properties: - return False - return all(site in other for site in self) - - def __hash__(self) -> int: - # For now, just use the composition hash code. - return hash(self.composition) - - def __mul__(self, scaling_matrix: int | Sequence[int] | Sequence[Sequence[int]]) -> Structure: - """Make a supercell. Allowing to have sites outside the unit cell. - - Args: - scaling_matrix: A scaling matrix for transforming the lattice - vectors. Has to be all integers. Several options are possible: - - a. A full 3x3 scaling matrix defining the linear combination - of the old lattice vectors. e.g. [[2,1,0],[0,3,0],[0,0, - 1]] generates a new structure with lattice vectors a' = - 2a + b, b' = 3b, c' = c where a, b, and c are the lattice - vectors of the original structure. - b. A sequence of three scaling factors. e.g. [2, 1, 1] - specifies that the supercell should have dimensions 2a x b x - c. - c. A number, which simply scales all lattice vectors by the - same factor. - - Returns: - Supercell structure. Note that a Structure is always returned, - even if the input structure is a subclass of Structure. This is - to avoid different arguments signatures from causing problems. If - you prefer a subclass to return its own type, you need to override - this method in the subclass. - """ - scale_matrix = np.array(scaling_matrix, int) - if scale_matrix.shape != (3, 3): - scale_matrix = scale_matrix * np.eye(3) - new_lattice = Lattice(np.dot(scale_matrix, self.lattice.matrix)) - - frac_lattice = lattice_points_in_supercell(scale_matrix) - cart_lattice = new_lattice.get_cartesian_coords(frac_lattice) - - new_sites = [] - for site in self: - for vec in cart_lattice: - periodic_site = PeriodicSite( - site.species, - site.coords + vec, - new_lattice, - properties=site.properties, - coords_are_cartesian=True, - to_unit_cell=False, - skip_checks=True, - label=site.label, - ) - new_sites.append(periodic_site) - - new_charge = self._charge * np.linalg.det(scale_matrix) if self._charge else None - return Structure.from_sites(new_sites, charge=new_charge, to_unit_cell=True) - - def __rmul__(self, scaling_matrix): - """Similar to __mul__ to preserve commutativeness.""" - return self * scaling_matrix - @property def frac_coords(self): """Fractional coordinates as a Nx3 numpy array.""" @@ -2584,46 +2624,6 @@ def factors(n: int): return self.copy() - def __repr__(self) -> str: - outs = ["Structure Summary", repr(self.lattice)] - if self._charge: - outs.append(f"Overall Charge: {self._charge:+}") - for site in self: - outs.append(repr(site)) - return "\n".join(outs) - - def __str__(self) -> str: - def to_str(x) -> str: - return f"{x:>10.6f}" - - outs = [ - f"Full Formula ({self.composition.formula})", - f"Reduced Formula: {self.composition.reduced_formula}", - f"abc : {' '.join(to_str(i) for i in self.lattice.abc)}", - f"angles: {' '.join(to_str(i) for i in self.lattice.angles)}", - f"pbc : {' '.join(str(p).rjust(10) for p in self.lattice.pbc)}", - ] - - if self._charge: - outs.append(f"Overall Charge: {self._charge:+}") - outs.append(f"Sites ({len(self)})") - data = [] - props = self.site_properties - keys = sorted(props) - for idx, site in enumerate(self): - row = [str(idx), site.species_string] - row.extend([to_str(j) for j in site.frac_coords]) - for key in keys: - row.append(props[key][idx]) - data.append(row) - outs.append( - tabulate( - data, - headers=["#", "SP", "a", "b", "c", *keys], - ) - ) - return "\n".join(outs) - def get_orderings(self, mode: Literal["enum", "sqs"] = "enum", **kwargs) -> list[Structure]: """Get list of orderings for a disordered structure. If structure does not contain disorder, the default structure is returned. @@ -3165,6 +3165,42 @@ def __init__( self._spin_multiplicity = 1 if n_electrons % 2 == 0 else 2 self.properties = properties or {} + def __eq__(self, other: object) -> bool: + needed_attrs = ("charge", "spin_multiplicity", "sites", "properties") + + if not all(hasattr(other, attr) for attr in needed_attrs): + return NotImplemented + + other = cast(IMolecule, other) + + if len(self) != len(other): + return False + if self.charge != other.charge: + return False + if self.spin_multiplicity != other.spin_multiplicity: + return False + if self.properties != other.properties: + return False + return all(site in other for site in self) + + def __hash__(self) -> int: + # For now, just use the composition hash code. + return hash(self.composition) + + def __repr__(self) -> str: + return "Molecule Summary\n" + "\n".join(map(repr, self)) + + def __str__(self) -> str: + outs = [ + f"Full Formula ({self.composition.formula})", + "Reduced Formula: " + self.composition.reduced_formula, + f"Charge = {self._charge}, Spin Mult = {self._spin_multiplicity}", + f"Sites ({len(self)})", + ] + for idx, site in enumerate(self): + outs.append(f"{idx} {site.species_string} {' '.join([f'{coord:0.6f}'.rjust(12) for coord in site.coords])}") + return "\n".join(outs) + @property def charge(self) -> float: """Charge of molecule.""" @@ -3310,24 +3346,6 @@ def get_covalent_bonds(self, tol: float = 0.2) -> list[CovalentBond]: bonds.append(CovalentBond(site1, site2)) return bonds - def __eq__(self, other: object) -> bool: - needed_attrs = ("charge", "spin_multiplicity", "sites", "properties") - - if not all(hasattr(other, attr) for attr in needed_attrs): - return NotImplemented - - other = cast(IMolecule, other) - - if len(self) != len(other): - return False - if self.charge != other.charge: - return False - if self.spin_multiplicity != other.spin_multiplicity: - return False - if self.properties != other.properties: - return False - return all(site in other for site in self) - def get_zmatrix(self): """Get a z-matrix representation of the molecule.""" # TODO: allow more z-matrix conventions for element/site description @@ -3363,24 +3381,6 @@ def _find_nn_pos_before_site(self, site_idx): all_dist = sorted(all_dist, key=lambda x: x[0]) return [d[1] for d in all_dist] - def __hash__(self) -> int: - # For now, just use the composition hash code. - return hash(self.composition) - - def __repr__(self) -> str: - return "Molecule Summary\n" + "\n".join(map(repr, self)) - - def __str__(self) -> str: - outs = [ - f"Full Formula ({self.composition.formula})", - "Reduced Formula: " + self.composition.reduced_formula, - f"Charge = {self._charge}, Spin Mult = {self._spin_multiplicity}", - f"Sites ({len(self)})", - ] - for idx, site in enumerate(self): - outs.append(f"{idx} {site.species_string} {' '.join([f'{coord:0.6f}'.rjust(12) for coord in site.coords])}") - return "\n".join(outs) - def as_dict(self): """JSON-serializable dict representation of Molecule.""" dct = { diff --git a/pymatgen/core/tensors.py b/pymatgen/core/tensors.py index 954764eea06..83aacf3e00b 100644 --- a/pymatgen/core/tensors.py +++ b/pymatgen/core/tensors.py @@ -1,6 +1,6 @@ -"""This module provides a base class for tensor-like objects and methods for -basic tensor manipulation. It also provides a class, SquareTensor, -that provides basic methods for creating and manipulating rank 2 tensors. +"""This module provides a base class Tensor for tensor-like objects and +methods for basic tensor manipulation. It also provides SquareTensor, +which provides basic methods for creating and manipulating rank 2 tensors. """ from __future__ import annotations @@ -24,7 +24,9 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing import Any + from numpy.typing import NDArray from typing_extensions import Self from pymatgen.core import Structure @@ -33,9 +35,6 @@ __credits__ = "Maarten de Jong, Shyam Dwaraknath, Wei Chen, Mark Asta, Anubhav Jain, Terence Lew" -voigt_map = [(0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1)] -reverse_voigt_map = np.array([[0, 5, 4], [5, 1, 3], [4, 3, 2]]) - DEFAULT_QUAD = loadfn(os.path.join(os.path.dirname(__file__), "quad_data.json")) @@ -46,7 +45,12 @@ class Tensor(np.ndarray, MSONable): symbol = "T" - def __new__(cls, input_array, vscale=None, check_rank=None) -> Self: + def __new__( + cls, + input_array: NDArray, + vscale: NDArray | None = None, + check_rank: int | None = None, + ) -> Self: """Create a Tensor object. Note that the constructor uses __new__ rather than __init__ according to the standard method of subclassing numpy ndarrays. @@ -71,7 +75,7 @@ def __new__(cls, input_array, vscale=None, check_rank=None) -> Self: obj._vscale = vscale if obj._vscale.shape != vshape: raise ValueError("Voigt scaling matrix must be the shape of the Voigt notation matrix or vector.") - if not all(dim == 3 for dim in obj.shape): + if any(dim != 3 for dim in obj.shape): raise ValueError( "Pymatgen only supports 3-dimensional tensors, and default tensor constructor uses standard " f"notation. To construct from Voigt notation, use {type(obj).__name__}.from_voigt" @@ -100,13 +104,13 @@ def __hash__(self) -> int: def __repr__(self) -> str: return f"{type(self).__name__}({self})" - def zeroed(self, tol: float = 1e-3): + def zeroed(self, tol: float = 1e-3) -> Self: """Get the matrix with all entries below a certain threshold (i.e. tol) set to zero.""" new_tensor = self.copy() new_tensor[abs(new_tensor) < tol] = 0 return new_tensor - def transform(self, symm_op): + def transform(self, symm_op: SymmOp) -> Self: """Apply a transformation (via a symmetry operation) to a tensor. Args: @@ -114,7 +118,7 @@ def transform(self, symm_op): """ return type(self)(symm_op.transform_tensor(self)) - def rotate(self, matrix, tol: float = 1e-3): + def rotate(self, matrix: NDArray, tol: float = 1e-3) -> Self: """Apply a rotation directly, and tests input matrix to ensure a valid rotation. @@ -128,7 +132,11 @@ def rotate(self, matrix, tol: float = 1e-3): symm_op = SymmOp.from_rotation_and_translation(matrix, [0.0, 0.0, 0.0]) return self.transform(symm_op) - def einsum_sequence(self, other_arrays, einsum_string=None): + def einsum_sequence( + self, + other_arrays: NDArray, + einsum_string: str | None = None, + ) -> NDArray: """Calculate the result of an einstein summation expression.""" if not isinstance(other_arrays, list): raise ValueError("other tensors must be list of tensors or tensor input") @@ -140,16 +148,15 @@ def einsum_sequence(self, other_arrays, einsum_string=None): other_ranks = [len(a.shape) for a in other_arrays] idx = self.rank - sum(other_ranks) for length in other_ranks: - einsum_string += "," + lc[idx : idx + length] + einsum_string += f",{lc[idx : idx + length]}" idx += length einsum_args = [self, *other_arrays] return np.einsum(einsum_string, *einsum_args) - def project(self, n): - """Convenience method for projection of a tensor into a - vector. Returns the tensor dotted into a unit vector - along the input n. + def project(self, n: NDArray) -> Self: + """Project a tensor into a vector. Returns the tensor + dotted into a unit vector along the input n. Args: n (3x1 array-like): direction to project onto @@ -161,7 +168,7 @@ def project(self, n): unit_vec = get_uvec(n) return self.einsum_sequence([unit_vec] * self.rank) - def average_over_unit_sphere(self, quad=None): + def average_over_unit_sphere(self, quad: dict | None = None) -> Self: """Average the tensor projection over the unit with option for custom quadrature. Args: @@ -176,7 +183,7 @@ def average_over_unit_sphere(self, quad=None): weights, points = quad["weights"], quad["points"] return sum(w * self.project(n) for w, n in zip(weights, points)) - def get_grouped_indices(self, voigt=False, **kwargs): + def get_grouped_indices(self, voigt: bool = False, **kwargs) -> list[list]: """Get index sets for equivalent tensor values. Args: @@ -211,7 +218,12 @@ def get_grouped_indices(self, voigt=False, **kwargs): # Don't return any empty lists return [g for g in grouped if g] - def get_symbol_dict(self, voigt=True, zero_index=False, **kwargs): + def get_symbol_dict( + self, + voigt: bool = True, + zero_index: bool = False, + **kwargs, + ) -> dict[str, NDArray]: """Create a summary dict for tensor with associated symbol. Args: @@ -246,12 +258,12 @@ def get_symbol_dict(self, voigt=True, zero_index=False, **kwargs): dct[sym_string] = array[indices[0]] return dct - def round(self, decimals=0): + def round(self, decimals: int = 0) -> Self: """Wrapper around numpy.round to ensure object of same type is returned. Args: - decimals :Number of decimal places to round to (default: 0). + decimals: Number of decimal places to round to (default: 0). If decimals is negative, it specifies the number of positions to the left of the decimal point. @@ -261,7 +273,7 @@ def round(self, decimals=0): return type(self)(np.round(self, decimals=decimals)) @property - def symmetrized(self): + def symmetrized(self) -> Self: """A generally symmetrized tensor, calculated by taking the sum of the tensor and its transpose with respect to all possible permutations of indices. @@ -270,11 +282,11 @@ def symmetrized(self): return sum(np.transpose(self, ind) for ind in perms) / len(perms) @property - def voigt_symmetrized(self): + def voigt_symmetrized(self) -> Self: """A "voigt"-symmetrized tensor, i. e. a Voigt-notation tensor such that it is invariant w.r.t. permutation of indices. """ - if not (self.rank % 2 == 0 and self.rank >= 2): + if self.rank % 2 != 0 or self.rank < 2: raise ValueError("V-symmetrization requires rank even and >= 2") v = self.voigt @@ -282,7 +294,7 @@ def voigt_symmetrized(self): new_v = sum(np.transpose(v, ind) for ind in perms) / len(perms) return type(self).from_voigt(new_v) - def is_symmetric(self, tol: float = 1e-5): + def is_symmetric(self, tol: float = 1e-5) -> bool: """Test whether a tensor is symmetric or not based on the residual with its symmetric part, from self.symmetrized. @@ -291,7 +303,11 @@ def is_symmetric(self, tol: float = 1e-5): """ return (self - self.symmetrized < tol).all() - def fit_to_structure(self, structure: Structure, symprec: float = 0.1): + def fit_to_structure( + self, + structure: Structure, + symprec: float = 0.1, + ): """Get a tensor that is invariant with respect to symmetry operations corresponding to a structure. @@ -305,7 +321,7 @@ def fit_to_structure(self, structure: Structure, symprec: float = 0.1): symm_ops = sga.get_symmetry_operations(cartesian=True) return sum(self.transform(symm_op) for symm_op in symm_ops) / len(symm_ops) - def is_fit_to_structure(self, structure: Structure, tol: float = 1e-2): + def is_fit_to_structure(self, structure: Structure, tol: float = 1e-2) -> bool: """Test whether a tensor is invariant with respect to the symmetry operations of a particular structure by testing whether the residual of the symmetric portion is below a @@ -318,7 +334,7 @@ def is_fit_to_structure(self, structure: Structure, tol: float = 1e-2): return (self - self.fit_to_structure(structure) < tol).all() @property - def voigt(self): + def voigt(self) -> NDArray: """The tensor in Voigt notation.""" v_matrix = np.zeros(self._vscale.shape, dtype=self.dtype) this_voigt_map = self.get_voigt_dict(self.rank) @@ -333,7 +349,7 @@ def is_voigt_symmetric(self, tol: float = 1e-6) -> bool: by grouping indices into pairs and constructing a sequence of possible permutations to be used in a tensor transpose. """ - transpose_pieces = [[[0 for i in range(self.rank % 2)]]] + transpose_pieces = [[[0 for _ in range(self.rank % 2)]]] transpose_pieces += [[list(range(j, j + 2))] for j in range(self.rank % 2, self.rank, 2)] for n in range(self.rank % 2, len(transpose_pieces)): if len(transpose_pieces[n][0]) == 2: @@ -345,13 +361,15 @@ def is_voigt_symmetric(self, tol: float = 1e-6) -> bool: return True @staticmethod - def get_voigt_dict(rank): + def get_voigt_dict(rank: int) -> dict[tuple[int, ...], tuple[int, ...]]: """Get a dictionary that maps indices in the tensor to those in a voigt representation based on input rank. Args: rank (int): Tensor rank to generate the voigt map """ + reverse_voigt_map = np.array([[0, 5, 4], [5, 1, 3], [4, 3, 2]]) + voigt_dict = {} for ind in itertools.product(*[range(3)] * rank): v_ind = ind[: rank % 2] @@ -362,7 +380,7 @@ def get_voigt_dict(rank): return voigt_dict @classmethod - def from_voigt(cls, voigt_input) -> Self: + def from_voigt(cls, voigt_input: NDArray) -> Self: """Constructor based on the voigt notation vector or matrix. Args: @@ -380,7 +398,10 @@ def from_voigt(cls, voigt_input) -> Self: return cls(t) @staticmethod - def get_ieee_rotation(structure, refine_rotation=True): + def get_ieee_rotation( + structure: Structure, + refine_rotation: bool = True, + ) -> SquareTensor: """Given a structure associated with a tensor, determines the rotation matrix for IEEE conversion according to the 1987 IEEE standards. @@ -452,7 +473,12 @@ def get_ieee_rotation(structure, refine_rotation=True): return rotation - def convert_to_ieee(self, structure: Structure, initial_fit=True, refine_rotation=True): + def convert_to_ieee( + self, + structure: Structure, + initial_fit: bool = True, + refine_rotation: bool = True, + ) -> Self: """Given a structure associated with a tensor, attempts a calculation of the tensor in IEEE format according to the 1987 IEEE standards. @@ -475,7 +501,12 @@ def convert_to_ieee(self, structure: Structure, initial_fit=True, refine_rotatio result = result.fit_to_structure(structure) return result.rotate(rotation, tol=1e-2) - def structure_transform(self, original_structure, new_structure, refine_rotation=True): + def structure_transform( + self, + original_structure: Structure, + new_structure: Structure, + refine_rotation: bool = True, + ) -> Self: """Transforms a tensor from one basis for an original structure into a new basis defined by a new structure. @@ -504,13 +535,13 @@ def structure_transform(self, original_structure, new_structure, refine_rotation @classmethod def from_values_indices( cls, - values, - indices, - populate=False, - structure=None, - voigt_rank=None, - vsym=True, - verbose=False, + values: list[float], + indices: NDArray, + populate: bool = False, + structure: Structure | None = None, + voigt_rank: int | None = None, + vsym: bool = True, + verbose: bool = False, ) -> Self: """Create a tensor from values and indices, with options for populating the remainder of the tensor. @@ -530,7 +561,7 @@ def from_values_indices( optimization procedure verbose (bool): whether to populate verbosely """ - # auto-detect voigt notation + # Auto-detect voigt notation # TODO: refactor rank inheritance to make this easier indices = np.array(indices) if voigt_rank: @@ -559,7 +590,7 @@ def populate( verbose: bool = False, precond: bool = True, vsym: bool = True, - ) -> Tensor: + ) -> Self: """Takes a partially populated tensor, and populates the non-zero entries according to the following procedure, iterated until the desired convergence (specified via prec) is achieved. @@ -582,7 +613,7 @@ def populate( Returns: Tensor: Populated tensor """ - guess = Tensor(np.zeros(self.shape)) + guess = type(self)(np.zeros(self.shape)) mask = None if precond: # Generate the guess from populated @@ -613,7 +644,7 @@ def merge(old, new) -> None: for perm in perms: vtrans = np.transpose(v, perm) merge(v, vtrans) - guess = Tensor.from_voigt(v) + guess = type(self).from_voigt(v) assert guess.shape == self.shape, "Guess must have same shape" converged = False @@ -689,7 +720,7 @@ def __getitem__(self, ind): def __iter__(self): return iter(self.tensors) - def zeroed(self, tol: float = 1e-3): + def zeroed(self, tol: float = 1e-3) -> Self: """ Args: tol: Tolerance. @@ -699,7 +730,7 @@ def zeroed(self, tol: float = 1e-3): """ return type(self)([tensor.zeroed(tol) for tensor in self]) - def transform(self, symm_op): + def transform(self, symm_op: SymmOp) -> Self: """Transforms TensorCollection with a symmetry operation. Args: @@ -710,7 +741,7 @@ def transform(self, symm_op): """ return type(self)([tensor.transform(symm_op) for tensor in self]) - def rotate(self, matrix, tol: float = 1e-3): + def rotate(self, matrix, tol: float = 1e-3) -> Self: """Rotates TensorCollection. Args: @@ -723,7 +754,7 @@ def rotate(self, matrix, tol: float = 1e-3): return type(self)([tensor.rotate(matrix, tol) for tensor in self]) @property - def symmetrized(self): + def symmetrized(self) -> Self: """TensorCollection where all tensors are symmetrized.""" return type(self)([tensor.symmetrized for tensor in self]) @@ -737,7 +768,11 @@ def is_symmetric(self, tol: float = 1e-5) -> bool: """ return all(tensor.is_symmetric(tol) for tensor in self) - def fit_to_structure(self, structure: Structure, symprec: float = 0.1): + def fit_to_structure( + self, + structure: Structure, + symprec: float = 0.1, + ) -> Self: """Fit all tensors to a Structure. Args: @@ -749,7 +784,11 @@ def fit_to_structure(self, structure: Structure, symprec: float = 0.1): """ return type(self)([tensor.fit_to_structure(structure, symprec) for tensor in self]) - def is_fit_to_structure(self, structure: Structure, tol: float = 1e-2): + def is_fit_to_structure( + self, + structure: Structure, + tol: float = 1e-2, + ) -> bool: """ Args: structure: Structure @@ -761,12 +800,12 @@ def is_fit_to_structure(self, structure: Structure, tol: float = 1e-2): return all(tensor.is_fit_to_structure(structure, tol) for tensor in self) @property - def voigt(self): + def voigt(self) -> list[NDArray]: """TensorCollection where all tensors are in Voigt form.""" return [tensor.voigt for tensor in self] @property - def ranks(self): + def ranks(self) -> list: """Ranks for all tensors.""" return [tensor.rank for tensor in self] @@ -781,7 +820,11 @@ def is_voigt_symmetric(self, tol: float = 1e-6) -> bool: return all(tensor.is_voigt_symmetric(tol) for tensor in self) @classmethod - def from_voigt(cls, voigt_input_list, base_class=Tensor) -> Self: + def from_voigt( + cls, + voigt_input_list: list[Tensor], + base_class=Tensor, + ) -> Self: """Create TensorCollection from voigt form. Args: @@ -793,7 +836,12 @@ def from_voigt(cls, voigt_input_list, base_class=Tensor) -> Self: """ return cls([base_class.from_voigt(v) for v in voigt_input_list]) - def convert_to_ieee(self, structure: Structure, initial_fit=True, refine_rotation=True): + def convert_to_ieee( + self, + structure: Structure, + initial_fit: bool = True, + refine_rotation: bool = True, + ) -> Self: """Convert all tensors to IEEE. Args: @@ -806,7 +854,7 @@ def convert_to_ieee(self, structure: Structure, initial_fit=True, refine_rotatio """ return type(self)([tensor.convert_to_ieee(structure, initial_fit, refine_rotation) for tensor in self]) - def round(self, *args, **kwargs): + def round(self, *args, **kwargs) -> Self: """Round all tensors. Args: @@ -819,11 +867,11 @@ def round(self, *args, **kwargs): return type(self)([tensor.round(*args, **kwargs) for tensor in self]) @property - def voigt_symmetrized(self): + def voigt_symmetrized(self) -> Self: """TensorCollection where all tensors are voigt symmetrized.""" return type(self)([tensor.voigt_symmetrized for tensor in self]) - def as_dict(self, voigt=False): + def as_dict(self, voigt: bool = False) -> dict: """ Args: voigt: Whether to use Voigt form. @@ -832,7 +880,7 @@ def as_dict(self, voigt=False): Dict representation of TensorCollection. """ tensor_list = self.voigt if voigt else self - dct = { + dct: dict[str, Any] = { "@module": type(self).__module__, "@class": type(self).__name__, "tensor_list": [tensor.tolist() for tensor in tensor_list], @@ -851,8 +899,7 @@ def from_dict(cls, dct: dict) -> Self: Returns: TensorCollection """ - voigt = dct.get("voigt") - if voigt: + if dct.get("voigt"): return cls.from_voigt(dct["tensor_list"]) return cls(dct["tensor_list"]) @@ -862,7 +909,11 @@ class SquareTensor(Tensor): (stress, strain etc.). """ - def __new__(cls, input_array, vscale=None) -> Self: + def __new__( + cls, + input_array: NDArray, + vscale: NDArray | None = None, + ) -> Self: """Create a SquareTensor object. Note that the constructor uses __new__ rather than __init__ according to the standard method of subclassing numpy ndarrays. Error is thrown when the class is initialized with non-square matrix. @@ -877,23 +928,27 @@ def __new__(cls, input_array, vscale=None) -> Self: return obj.view(cls) @property - def trans(self): + def trans(self) -> Self: """Shorthand for transpose on SquareTensor.""" - return SquareTensor(np.transpose(self)) + return type(self)(np.transpose(self)) @property - def inv(self): + def inv(self) -> Self: """Shorthand for matrix inverse on SquareTensor.""" if self.det == 0: raise ValueError("SquareTensor is non-invertible") - return SquareTensor(np.linalg.inv(self)) + return type(self)(np.linalg.inv(self)) @property - def det(self): + def det(self) -> Self: """Shorthand for the determinant of the SquareTensor.""" return np.linalg.det(self) - def is_rotation(self, tol: float = 1e-3, include_improper=True): + def is_rotation( + self, + tol: float = 1e-3, + include_improper: bool = True, + ) -> bool: """Test to see if tensor is a valid rotation matrix, performs a test to check whether the inverse is equal to the transpose and if the determinant is equal to one within the specified @@ -911,7 +966,7 @@ def is_rotation(self, tol: float = 1e-3, include_improper=True): det = np.abs(det) return (np.abs(self.inv - self.trans) < tol).all() and (np.abs(det - 1.0) < tol) - def refine_rotation(self): + def refine_rotation(self) -> Self: """Helper method for refining rotation matrix by ensuring that second and third rows are perpendicular to the first. Gets new y vector from an orthogonal projection of x onto y @@ -927,39 +982,42 @@ def refine_rotation(self): # Get a projection on y new_y = y - np.dot(new_x, y) * new_x new_z = np.cross(new_x, new_y) - return SquareTensor([new_x, new_y, new_z]) + return type(self)([new_x, new_y, new_z]) - def get_scaled(self, scale_factor): + def get_scaled(self, scale_factor: float) -> Self: """Scales the tensor by a certain multiplicative scale factor. Args: scale_factor (float): scalar multiplier to be applied to the SquareTensor object """ - return SquareTensor(self * scale_factor) + return type(self)(self * scale_factor) @property - def principal_invariants(self): + def principal_invariants(self) -> NDArray: """A list of principal invariants for the tensor, which are the values of the coefficients of the characteristic polynomial for the matrix. """ return np.poly(self)[1:] * np.array([-1, 1, -1]) - def polar_decomposition(self, side="right"): + def polar_decomposition(self, side: str = "right") -> tuple: """Calculate matrices for polar decomposition.""" return polar(self, side=side) -def get_uvec(vec: np.ndarray) -> np.ndarray: +def get_uvec(vec: NDArray) -> NDArray: """Get a unit vector parallel to input vector.""" norm = np.linalg.norm(vec) - if norm < 1e-8: - return vec - return vec / norm + return vec if norm < 1e-8 else vec / norm -def symmetry_reduce(tensors, structure: Structure, tol: float = 1e-8, **kwargs): +def symmetry_reduce( + tensors, + structure: Structure, + tol: float = 1e-8, + **kwargs, +) -> TensorMapping: """Convert a list of tensors corresponding to a structure and returns a dictionary consisting of unique tensor keys with SymmOp values corresponding to transformations that will result in derivative @@ -1003,7 +1061,12 @@ class TensorMapping(collections.abc.MutableMapping): and should be used with care. """ - def __init__(self, tensors: Sequence[Tensor] = (), values: Sequence = (), tol: float = 1e-5) -> None: + def __init__( + self, + tensors: Sequence[Tensor] = (), + values: Sequence = (), + tol: float = 1e-5, + ) -> None: """Initialize a TensorMapping. Args: @@ -1046,6 +1109,9 @@ def __len__(self) -> int: def __iter__(self): yield from self._tensor_list + def __contains__(self, item) -> bool: + return self._get_item_index(item) is not None + def values(self): """Values in mapping.""" return self._value_list @@ -1054,9 +1120,6 @@ def items(self): """Items in mapping.""" return zip(self._tensor_list, self._value_list) - def __contains__(self, item) -> bool: - return self._get_item_index(item) is not None - def _get_item_index(self, item): if len(self._tensor_list) == 0: return None diff --git a/pymatgen/core/xcfunc.py b/pymatgen/core/xcfunc.py index 6df4c540e00..4a6a2939158 100644 --- a/pymatgen/core/xcfunc.py +++ b/pymatgen/core/xcfunc.py @@ -16,7 +16,7 @@ __author__ = "Matteo Giantomassi" __copyright__ = "Copyright 2016, The Materials Project" -__version__ = "3.0.0" # The libxc version used to generate this file! +__version__ = "3.0.0" # The libxc version used to generate this file __maintainer__ = "Matteo Giantomassi" __email__ = "gmatteo@gmail.com" __status__ = "Production" @@ -120,7 +120,12 @@ class type_name(NamedTuple): del xcf - def __init__(self, xc: LibxcFunc | None = None, x: LibxcFunc | None = None, c: LibxcFunc | None = None) -> None: + def __init__( + self, + xc: LibxcFunc | None = None, + x: LibxcFunc | None = None, + c: LibxcFunc | None = None, + ) -> None: """ Args: xc: LibxcFunc for XC functional. @@ -136,6 +141,20 @@ def __init__(self, xc: LibxcFunc | None = None, x: LibxcFunc | None = None, c: L self.xc, self.x, self.c = xc, x, c + def __repr__(self) -> str: + return str(self.name) + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, (str, type(self))): + return NotImplemented + if isinstance(other, type(self)): + return self.name == other.name + # Assume other is a string + return self.name == other + @classmethod def aliases(cls) -> list[str]: """List of registered names.""" @@ -143,12 +162,12 @@ def aliases(cls) -> list[str]: @classmethod def asxc(cls, obj) -> Self: - """Convert object into Xcfunc.""" + """Convert object into XcFunc.""" if isinstance(obj, cls): return obj if isinstance(obj, str): return cls.from_name(obj) - raise TypeError(f"Don't know how to convert <{type(obj)}:{obj}> to Xcfunc") + raise TypeError(f"Don't know how to convert <{type(obj)}:{obj}> to XcFunc") @classmethod def from_abinit_ixc(cls, ixc: int) -> Self | None: @@ -158,7 +177,8 @@ def from_abinit_ixc(cls, ixc: int) -> Self | None: if ixc > 0: return cls(**cls.abinitixc_to_libxc[ixc]) - # libxc notation employed in Abinit: a six-digit number in the form XXXCCC or CCCXXX + # libxc notation employed in Abinit: a six-digit number + # in the form XXXCCC or CCCXXX ixc = abs(ixc) first = ixc // 1000 last = ixc - first * 1000 @@ -177,7 +197,7 @@ def from_name(cls, name: str) -> Self: @classmethod def from_type_name(cls, typ: str | None, name: str) -> Self: """Build the object from (type, name).""" - # Try aliases first. + # Try aliases first for k, nt in cls.defined_aliases.items(): if typ is not None and typ != nt.type: continue @@ -250,17 +270,3 @@ def name(self) -> str | None: return f"{self.x.name}+{self.c.name}" return None - - def __repr__(self) -> str: - return str(self.name) - - def __hash__(self) -> int: - return hash(self.name) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, (str, XcFunc)): - return NotImplemented - if isinstance(other, XcFunc): - return self.name == other.name - # assume other is a string - return self.name == other diff --git a/pymatgen/io/cp2k/inputs.py b/pymatgen/io/cp2k/inputs.py index 46ebe635fa3..ee8bcb5b6d3 100644 --- a/pymatgen/io/cp2k/inputs.py +++ b/pymatgen/io/cp2k/inputs.py @@ -29,7 +29,6 @@ import os import re import textwrap -import typing from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from pathlib import Path @@ -2393,7 +2392,6 @@ def nexp(self): """Number of exponents.""" return [len(exp) for exp in self.exponents] - @typing.no_type_check def get_str(self) -> str: """Get standard cp2k GTO formatted string.""" if ( # written verbosely so mypy can perform type narrowing diff --git a/pymatgen/io/cssr.py b/pymatgen/io/cssr.py index 42d3e276e4c..e4391b449bc 100644 --- a/pymatgen/io/cssr.py +++ b/pymatgen/io/cssr.py @@ -76,7 +76,7 @@ def from_str(cls, string: str) -> Self: tokens = lines[0].split() lengths = [float(tok) for tok in tokens] tokens = lines[1].split() - angles = [float(tok) for tok in tokens[0:3]] + angles = [float(tok) for tok in tokens[:3]] lattice = Lattice.from_parameters(*lengths, *angles) sp, coords = [], [] for line in lines[4:]: diff --git a/pymatgen/io/fiesta.py b/pymatgen/io/fiesta.py index 25f80c9c7b3..03c142462df 100644 --- a/pymatgen/io/fiesta.py +++ b/pymatgen/io/fiesta.py @@ -686,7 +686,7 @@ def from_str(cls, string_input: str) -> Self: while i != 0: line = lines.pop(0).strip() tokens = line.split() - coords.append([float(j) for j in tokens[0:3]]) + coords.append([float(j) for j in tokens[:3]]) species.append(atname[int(tokens[3]) - 1]) i -= 1 diff --git a/pymatgen/io/qchem/outputs.py b/pymatgen/io/qchem/outputs.py index 7a97d4c2885..5c238fe85e2 100644 --- a/pymatgen/io/qchem/outputs.py +++ b/pymatgen/io/qchem/outputs.py @@ -2409,7 +2409,7 @@ def parse_hyperbonds(lines: list[str]) -> list[pd.DataFrame]: # Extract the values entry: dict[str, str | float] = {} - entry["hyperbond index"] = int(line[0:4].strip()) + entry["hyperbond index"] = int(line[:4].strip()) entry["bond atom 1 symbol"] = str(line[5:8].strip()) entry["bond atom 1 index"] = int(line[8:11].strip()) entry["bond atom 2 symbol"] = str(line[13:15].strip()) @@ -2492,7 +2492,7 @@ def parse_hybridization_character(lines: list[str]) -> list[pd.DataFrame]: # Lone pair if "LP" in line or "LV" in line: LPentry: dict[str, str | float] = dict.fromkeys(orbitals, 0.0) - LPentry["bond index"] = line[0:4].strip() + LPentry["bond index"] = line[:4].strip() LPentry["occupancy"] = line[7:14].strip() LPentry["type"] = line[16:19].strip() LPentry["orbital index"] = line[20:22].strip() @@ -2521,7 +2521,7 @@ def parse_hybridization_character(lines: list[str]) -> list[pd.DataFrame]: BDentry: dict[str, str | float] = { f"atom {i} {orbital}": 0.0 for orbital in orbitals for i in range(1, 3) } - BDentry["bond index"] = line[0:4].strip() + BDentry["bond index"] = line[:4].strip() BDentry["occupancy"] = line[7:14].strip() BDentry["type"] = line[16:19].strip() BDentry["orbital index"] = line[20:22].strip() @@ -2581,7 +2581,7 @@ def parse_hybridization_character(lines: list[str]) -> list[pd.DataFrame]: TCentry: dict[str, str | float] = { f"atom {i} {orbital}": 0.0 for orbital in orbitals for i in range(1, 4) } - TCentry["bond index"] = line[0:4].strip() + TCentry["bond index"] = line[:4].strip() TCentry["occupancy"] = line[7:14].strip() TCentry["type"] = line[16:19].strip() TCentry["orbital index"] = line[20:22].strip() @@ -2725,7 +2725,7 @@ def parse_perturbation_energy(lines: list[str]) -> list[pd.DataFrame]: second_3C = True if line[4] == ".": - entry["donor bond index"] = int(line[0:4].strip()) + entry["donor bond index"] = int(line[:4].strip()) entry["donor type"] = str(line[5:9].strip()) entry["donor orbital index"] = int(line[10:12].strip()) entry["donor atom 1 symbol"] = str(line[13:15].strip()) @@ -2743,7 +2743,7 @@ def parse_perturbation_energy(lines: list[str]) -> list[pd.DataFrame]: entry["energy difference"] = float(line[62:70].strip()) entry["fock matrix element"] = float(line[70:79].strip()) elif line[5] == ".": - entry["donor bond index"] = int(line[0:5].strip()) + entry["donor bond index"] = int(line[:5].strip()) entry["donor type"] = str(line[6:10].strip()) entry["donor orbital index"] = int(line[11:13].strip()) diff --git a/pymatgen/io/xr.py b/pymatgen/io/xr.py index 789651e66e4..8c28325c113 100644 --- a/pymatgen/io/xr.py +++ b/pymatgen/io/xr.py @@ -96,7 +96,7 @@ def from_str(cls, string: str, use_cores: bool = True, thresh: float = 1.0e-4) - tokens = lines[0].split() lengths = [float(tokens[i]) for i in range(1, len(tokens))] tokens = lines[1].split() - angles = [float(i) for i in tokens[0:3]] + angles = [float(i) for i in tokens[:3]] tokens = lines[2].split() n_sites = int(tokens[0]) mat = np.zeros((3, 3), dtype=float) diff --git a/pymatgen/io/zeopp.py b/pymatgen/io/zeopp.py index 2fba4cf677a..70f8d13e141 100644 --- a/pymatgen/io/zeopp.py +++ b/pymatgen/io/zeopp.py @@ -112,7 +112,7 @@ def from_str(cls, string: str) -> Self: tokens = lines[0].split() lengths = [float(i) for i in tokens] tokens = lines[1].split() - angles = [float(i) for i in tokens[0:3]] + angles = [float(i) for i in tokens[:3]] # Zeo++ takes x-axis along a and pymatgen takes z-axis along c a = lengths.pop(-1) lengths.insert(0, a) diff --git a/pymatgen/phonon/bandstructure.py b/pymatgen/phonon/bandstructure.py index 34dfa175d5a..abc1c391d4e 100644 --- a/pymatgen/phonon/bandstructure.py +++ b/pymatgen/phonon/bandstructure.py @@ -254,7 +254,7 @@ def asr_breaking(self, tol_eigendisplacements: float = 1e-5) -> np.ndarray | Non modes: selects the bands corresponding to the eigendisplacements that represent to a translation within tol_eigendisplacements. If these are not identified or eigendisplacements are missing the first 3 modes will be used - (indices [0:3]). + (indices [:3]). """ for idx in range(self.nb_qpoints): if np.allclose(self.qpoints[idx].frac_coords, (0, 0, 0)): diff --git a/pymatgen/transformations/advanced_transformations.py b/pymatgen/transformations/advanced_transformations.py index 519864c27e2..3e44b713df6 100644 --- a/pymatgen/transformations/advanced_transformations.py +++ b/pymatgen/transformations/advanced_transformations.py @@ -9,7 +9,7 @@ from itertools import groupby, product from math import gcd from string import ascii_lowercase -from typing import TYPE_CHECKING, Callable, Literal +from typing import TYPE_CHECKING import numpy as np from joblib import Parallel, delayed @@ -47,7 +47,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence - from typing import Any + from typing import Any, Callable, Literal __author__ = "Shyue Ping Ong, Stephen Dacek, Anubhav Jain, Matthew Horton, Alex Ganose" @@ -861,8 +861,8 @@ def apply_transformation( alls = self._add_spin_magnitudes(alls) # type: ignore[arg-type] else: for idx in range(len(alls)): - alls[idx]["structure"] = self._remove_dummy_species(alls[idx]["structure"]) - alls[idx]["structure"] = self._add_spin_magnitudes(alls[idx]["structure"]) + alls[idx]["structure"] = self._remove_dummy_species(alls[idx]["structure"]) # type: ignore[index] + alls[idx]["structure"] = self._add_spin_magnitudes(alls[idx]["structure"]) # type: ignore[index, arg-type] try: num_to_return = int(return_ranked_list) @@ -870,7 +870,7 @@ def apply_transformation( num_to_return = 1 if num_to_return == 1 or not return_ranked_list: - return alls[0]["structure"] if num_to_return else alls # type: ignore[return-value] + return alls[0]["structure"] if num_to_return else alls # type: ignore[return-value, index] # remove duplicate structures and group according to energy model matcher = StructureMatcher(comparator=SpinComparator()) @@ -879,7 +879,7 @@ def key(struct: Structure) -> int: return SpacegroupAnalyzer(struct, 0.1).get_space_group_number() out = [] - for _, group in groupby(sorted((dct["structure"] for dct in alls), key=key), key): + for _, group in groupby(sorted((dct["structure"] for dct in alls), key=key), key): # type: ignore[arg-type, index] group = list(group) # type: ignore grouped = matcher.group_structures(group) out.extend([{"structure": g[0], "energy": self.energy_model.get_energy(g[0])} for g in grouped]) diff --git a/tests/core/test_structure.py b/tests/core/test_structure.py index 78435463f6f..9e7a999c2b7 100644 --- a/tests/core/test_structure.py +++ b/tests/core/test_structure.py @@ -938,7 +938,7 @@ def test_mutable_sequence_methods(self): # Test slice replacement. struct = PymatgenTest.get_structure("Li2O") - struct[0:2] = "S" + struct[:2] = "S" assert struct.formula == "Li1 S2" def test_not_hashable(self): @@ -1559,7 +1559,7 @@ def test_set_item(self): assert struct.formula == "Si1 C1" struct[(0, 1)] = "Ge" assert struct.formula == "Ge2" - struct[0:2] = "Sn" + struct[:2] = "Sn" assert struct.formula == "Sn2" struct = self.struct.copy() @@ -1891,7 +1891,7 @@ def test_set_item(self): assert mol.formula == "Si1 H4" mol[(0, 1)] = "Ge" assert mol.formula == "Ge2 H3" - mol[0:2] = "Sn" + mol[:2] = "Sn" assert mol.formula == "Sn2 H3" mol = self.mol.copy() diff --git a/tests/core/test_trajectory.py b/tests/core/test_trajectory.py index 4687042c0b7..03fe63181fe 100644 --- a/tests/core/test_trajectory.py +++ b/tests/core/test_trajectory.py @@ -89,8 +89,8 @@ def test_slice(self): else: raise AssertionError - sliced_traj = self.traj_mols[0:2] - sliced_traj_from_mols = Trajectory.from_molecules(self.molecules[0:2]) + sliced_traj = self.traj_mols[:2] + sliced_traj_from_mols = Trajectory.from_molecules(self.molecules[:2]) if len(sliced_traj) == len(sliced_traj_from_mols): assert all(sliced_traj[i] == sliced_traj_from_mols[i] for i in range(len(sliced_traj))) diff --git a/tests/core/test_units.py b/tests/core/test_units.py index 140fba2c409..00cebd8272c 100644 --- a/tests/core/test_units.py +++ b/tests/core/test_units.py @@ -237,7 +237,7 @@ def test_array_algebra(self): ene_ha * time_s, ene_ha / ene_ev, ene_ha.copy(), - ene_ha[0:1], + ene_ha[:1], e1, e2, e3, diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index e6ae3f06f21..5f0a897f0bf 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -1091,7 +1091,7 @@ def test_cif_writer_non_unique_labels(capsys): parser = CifParser(f"{TEST_FILES_DIR}/cif/garnet.cif") struct = parser.parse_structures()[0] - assert struct.labels[0:3] == ["Ca1", "Ca1", "Ca1"] + assert struct.labels[:3] == ["Ca1", "Ca1", "Ca1"] assert len(set(struct.labels)) != len(struct.labels) # This should raise a warning @@ -1099,7 +1099,7 @@ def test_cif_writer_non_unique_labels(capsys): CifWriter(struct) struct.relabel_sites() - assert struct.labels[0:3] == ["Ca1_1", "Ca1_2", "Ca1_3"] + assert struct.labels[:3] == ["Ca1_1", "Ca1_2", "Ca1_3"] _ = capsys.readouterr() # This should not raise a warning