diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 94dfc8f0..67a751d4 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.11" cache: pip cache-dependency-path: pyproject.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8e9e0cad..37c56c84 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: hooks: # Run the linter. - id: ruff - args: [ --fix ] + args: [ --fix, --unsafe-fixes] # Run the formatter. - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/pymatgen/analysis/defects/ccd.py b/pymatgen/analysis/defects/ccd.py index 7ca30d6c..5913e7ca 100644 --- a/pymatgen/analysis/defects/ccd.py +++ b/pymatgen/analysis/defects/ccd.py @@ -19,9 +19,9 @@ from .utils import get_localized_states, get_zfile, sort_positive_definite if TYPE_CHECKING: + from collections.abc import Sequence from ctypes import Structure from pathlib import Path - from typing import Optional, Sequence, Tuple import numpy.typing as npt from matplotlib.axes import Axes @@ -66,15 +66,15 @@ class HarmonicDefect(MSONable): omega: float charge_state: int ispin: int - vrun: Optional[Vasprun] = None - distortions: Optional[Sequence[float]] = None - structures: Optional[Sequence[Structure]] = None - energies: Optional[Sequence[float]] = None - defect_band: Optional[Sequence[tuple]] = None - relaxed_index: Optional[int] = None - relaxed_bandstructure: Optional[BandStructure] = None - wswqs: Optional[list[dict]] = None - waveder: Optional[Waveder] = None + vrun: Vasprun | None = None + distortions: Sequence[float] | None = None + structures: Sequence[Structure] | None = None + energies: Sequence[float] | None = None + defect_band: Sequence[tuple] | None = None + relaxed_index: int | None = None + relaxed_bandstructure: BandStructure | None = None + wswqs: list[dict] | None = None + waveder: Waveder | None = None def __repr__(self) -> str: """String representation of the harmonic defect.""" @@ -93,7 +93,8 @@ def defect_band_index(self) -> int: """The index of the defect band.""" bands = {band for band, _, _ in self.defect_band} if len(bands) != 1: - raise ValueError("Defect band index is not unique.") + msg = "Defect band index is not unique." + raise ValueError(msg) return bands.pop() @property @@ -105,7 +106,8 @@ def spin_index(self) -> int: """ spins = {spin for _, _, spin in self.defect_band} if len(spins) != 1: - raise ValueError("Spin index is not unique.") + msg = "Spin index is not unique." + raise ValueError(msg) return spins.pop() @property @@ -113,10 +115,10 @@ def spin(self) -> Spin: """The spin of the defect returned as an Spin Enum.""" if self.spin_index == 0: return Spin.up - elif self.spin_index == 1: + if self.spin_index == 1: return Spin.down - else: - raise ValueError(f"Invalid spin index: {self.spin_index}") + msg = f"Invalid spin index: {self.spin_index}" + raise ValueError(msg) @property def relaxed_structure(self) -> Structure: @@ -164,7 +166,7 @@ def from_vaspruns( A HarmonicDefect object. """ - def _parse_vasprun(vasprun: Vasprun): + def _parse_vasprun(vasprun: Vasprun) -> tuple[float, Structure]: energy = vasprun.final_energy struct = vasprun.final_structure return (energy, struct) @@ -184,8 +186,9 @@ def _parse_vasprun(vasprun: Vasprun): ) energies, structures = list(zip(*sorted_list)) - if not np.allclose(unsorted_e, energies, atol=1e-99): - raise ValueError("The vaspruns should already be in order.") + if not np.allclose(unsorted_e, energies, atol=1e-99): # pragma: no cover + msg = "The vaspruns should already be in order." + raise ValueError(msg) omega = _get_omega( Q=distortions, @@ -196,25 +199,25 @@ def _parse_vasprun(vasprun: Vasprun): get_band_structure_kwargs = get_band_structure_kwargs or {} bandstructure = vaspruns[relaxed_index].get_band_structure( - **get_band_structure_kwargs + **get_band_structure_kwargs, ) ispin = vaspruns[relaxed_index].parameters["ISPIN"] - if store_bandstructure: - bs = bandstructure - else: - bs = None + bs = bandstructure if store_bandstructure else None if defect_band is None: if procar is None: # pragma: no cover + msg = "If defect_band_index is not provided, you must provide a Procar object." raise ValueError( - "If defect_band_index is not provided, you must provide a Procar object." + msg, ) # Get the defect bands defect_band_2s = list( get_localized_states( - bandstructure=bandstructure, procar=procar, band_window=band_window - ) + bandstructure=bandstructure, + procar=procar, + band_window=band_window, + ), ) defect_band_2s.sort(key=lambda x: (x[2], x[1])) # group by the spin index @@ -282,13 +285,17 @@ def from_directories( if charge_state is None: if vaspruns[min_idx].final_structure._charge is None: - raise ValueError( + msg = ( "Charge state is not provided and cannot be parsed from the POTCAR." ) + raise ValueError( + msg, + ) charge_state = vaspruns[0].final_structure.charge if any(v.final_structure.charge != charge_state for v in vaspruns): - raise ValueError("All vaspruns must have the same charge state.") + msg = "All vaspruns must have the same charge state." + raise ValueError(msg) return cls.from_vaspruns( vaspruns=vaspruns, @@ -315,7 +322,9 @@ def occupation(self, t: npt.ArrayLike | float) -> npt.ArrayLike: return 1.0 / (1 - np.exp(-self.omega_eV / KB * t)) def read_wswqs( - self, directory: Path, distortions: Sequence[float] | None = None + self, + directory: Path, + distortions: Sequence[float] | None = None, ) -> None: """Read the WSWQ files from a directory. @@ -327,14 +336,15 @@ def read_wswqs( distortions: The distortions used to generate the WSWQ files, if different from self.distortions """ - wswq_files = [f for f in directory.glob("WSWQ*")] + wswq_files = list(directory.glob("WSWQ*")) wswq_files.sort(key=lambda x: int(x.name.split(".")[1])) if distortions is None: distortions = self.distortions if len(wswq_files) != len(distortions): + msg = f"Number of WSWQ files ({len(wswq_files)}) does not match number of distortions ({len(distortions)})." raise ValueError( - f"Number of WSWQ files ({len(wswq_files)}) does not match number of distortions ({len(distortions)})." + msg, ) self.wswqs = [ {"Q": d, "wswq": WSWQ.from_file(f)} for d, f in zip(distortions, wswq_files) @@ -361,19 +371,23 @@ def get_elph_me(self, defect_state: tuple) -> npt.ArrayLike: The indices are [band_j,] """ if self.wswqs is None: - raise RuntimeError("WSWQs have not been read. Use `read_wswqs` first.") + msg = "WSWQs have not been read. Use `read_wswqs` first." + raise RuntimeError(msg) distortions = [wswq["Q"] for wswq in self.wswqs] wswqs = [wswq["wswq"] for wswq in self.wswqs] band_index, kpoint_index, spin_index = defect_state # The second band index is associated with the defect state # since we are usually interested in capture slopes = _get_wswq_slope(distortions, wswqs)[ - spin_index, kpoint_index, :, band_index + spin_index, + kpoint_index, + :, + band_index, ] ediffs = self._get_ediff(output_order="skb")[spin_index, kpoint_index, :] return np.multiply(slopes, ediffs) - def _get_ediff(self, output_order="skb") -> npt.NDArray: + def _get_ediff(self, output_order: str = "skb") -> npt.NDArray: """Compute the eigenvalue difference to the defect band. .. note:: @@ -389,11 +403,14 @@ def _get_ediff(self, output_order="skb") -> npt.NDArray: The eigenvalue difference to the defect band in the order specified by output_order. """ - if self.relaxed_bandstructure is None: - raise ValueError( # pragma: no cover + if self.relaxed_bandstructure is None: # pragma: no cover + msg = ( "The ``relaxed_bandstructure`` must be set before ``ediff`` can be computed. " "Try setting ``store_bandstructure=True`` when initializing." ) + raise ValueError( + msg, + ) ediffs_ = _get_ks_ediff( bandstructure=self.relaxed_bandstructure, @@ -402,21 +419,23 @@ def _get_ediff(self, output_order="skb") -> npt.NDArray: ediffs_stack = [ ediffs_[Spin.up].T, ] - if Spin.down in ediffs_.keys(): + if Spin.down in ediffs_: ediffs_stack.append(ediffs_[Spin.down].T) ediffs = np.stack(ediffs_stack) if output_order == "skb": return ediffs - elif output_order == "bks": + if output_order == "bks": return ediffs.transpose((2, 1, 0)) - else: - raise ValueError( - "Invalid output_order, choose from 'skb' or 'bks'." - ) # pragma: no cover + msg = "Invalid output_order, choose from 'skb' or 'bks'." + raise ValueError( + msg, + ) def get_dielectric_function( - self, idir: int, jdir: int + self, + idir: int, + jdir: int, ) -> tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]: """Calculate the dielectric function. @@ -430,7 +449,8 @@ def get_dielectric_function( eps_cbm: The dielectric function from the defect state to the CBM. """ dfc = DielectricFunctionCalculator.from_vasp_objects( - vrun=self.vrun, waveder=self.waveder + vrun=self.vrun, + waveder=self.waveder, ) # two masks to select for VBM -> Defect and Defect -> CBM @@ -521,7 +541,8 @@ def get_SRH_coefficient( elif initial_state.charge_state == final_state.charge_state - 1: band_slice = slice(defect_band - n_band_edge, defect_band) else: - raise ValueError("SRH capture event must involve a charge state change of 1.") + msg = "SRH capture event must involve a charge state change of 1." + raise ValueError(msg) me_band_edge = me_all[band_slice] dQ = get_dQ(initial_state.relaxed_structure, final_state.relaxed_structure) @@ -539,137 +560,6 @@ def get_SRH_coefficient( ) -# @dataclass -# class RadiativeCatpture(MSONable): -# """Representation of a radiative capture event. - -# Attributes: -# initial_state: The initial state of the radiative capture event. -# final_state: The final state of the radiative capture event. -# dQ: The configuration coordinate change between the relaxed initial state and the final state. -# waveder: The data from the WAVEDER file obtained with ``LOPTICS=.True.``. - -# """ - -# initial_state: HarmonicDefect -# final_state: HarmonicDefect -# dQ: float -# waveder: Waveder - -# def get_coeff( -# self, -# T: float | npt.ArrayLike, -# dE: float, -# omega_photon: float, -# volume: float | None = None, -# g: int = 1, -# occ_tol: float = 1e-3, -# n_band_edge: int = 1, -# ): -# """Calculate the SRH recombination coefficient.""" -# if volume is None: -# volume = self.initial_state.relaxed_structure.volume - -# me_all = self.get_dipoles() # indices: [band, kpoint, spin, coord] - -# istate = self.initial_state - -# if self.initial_state.charge_state == self.final_state.charge_state + 1: -# band_slice = slice( -# istate.defect_band_index + 1, istate.defect_band_index + 1 + n_band_edge -# ) -# elif self.initial_state.charge_state == self.final_state.charge_state - 1: -# band_slice = slice( -# istate.defect_band_index - n_band_edge, istate.defect_band_index -# ) -# else: -# raise ValueError( -# "SRH capture event must involve a charge state change of 1." -# ) - -# me_band_edge = me_all[band_slice, istate.kpt_index, istate.spin_index] - -# return get_Rad_coef( -# T, -# dQ=self.dQ, -# dE=dE, -# omega_i=self.initial_state.omega_eV, -# omega_f=self.final_state.omega_eV, -# omega_photon=omega_photon, -# dipole_me=np.average(me_band_edge), -# volume=volume, -# g=g, -# occ_tol=occ_tol, -# ) - -# @classmethod -# def from_directories( -# cls, -# initial_dirs: list[Path], -# final_dirs: list[Path], -# waveder_dir: Path, -# kpt_index: int, -# initial_charge_state: int | None = None, -# final_charge_state: int | None = None, -# spin_index: int | None = None, -# defect_band_index: int | None = None, -# store_bandstructure: bool = False, -# get_band_structure_kwargs: dict | None = None, -# **kwargs, -# ) -> RadiativeCatpture: -# """Create a RadiativeCapture object from a list of directories. - -# Args: -# initial_dirs: A list of directories for the initial state. -# final_dirs: A list of directories for the final state. -# waveder_dir: The directory containing the WAVEDER file. -# kpt_index: The index of the k-point to use. -# initial_charge_state: The charge state of the initial state. -# If None, the charge state is determined from the vasprun.xml and POTCAR files. -# final_charge_state: The charge state of the final state. -# If None, the charge state is determined from the vasprun.xml and POTCAR files. -# spin_index: The index of the spin channel to use. -# If None, the spin channel is determined by the channel with the most localized state. -# defect_band_index: The index of the defect band (0-indexed). -# If None, the defect band is determined by the band with the most localized state. -# store_bandstructure: Whether to store the band structure. -# get_band_structure_kwargs: Keyword arguments to pass to get_band_structure. -# **kwargs: Keyword arguments to pass to the HarmonicDefect constructor. - -# Returns: -# A SRHCapture object. -# """ -# initial_defect = HarmonicDefect.from_directories( -# directories=initial_dirs, -# charge_state=initial_charge_state, -# spin_index=spin_index, -# relaxed_index=None, -# defect_band_index=defect_band_index, -# store_bandstructure=store_bandstructure, -# get_band_structure_kwargs=get_band_structure_kwargs, -# **kwargs, -# ) - -# # the final state does not need the additional -# # information about the electronic states -# final_defect = HarmonicDefect.from_directories( -# directories=final_dirs, -# kpt_index=kpt_index, -# charge_state=final_charge_state, -# spin_index=spin_index, -# relaxed_index=None, -# defect_band_index=None, -# store_bandstructure=None, -# get_band_structure_kwargs=None, -# **kwargs, -# ) - -# waveder_file = get_zfile(waveder_dir, "WAVEDER") -# waveder = Waveder(waveder_file) -# dQ = get_dQ(initial_defect.relaxed_structure, final_defect.relaxed_structure) -# return cls(initial_defect, final_defect, dQ=dQ, waveder=waveder) - - def get_dQ(ground: Structure, excited: Structure) -> float: """Calculate configuration coordinate difference. @@ -682,13 +572,11 @@ def get_dQ(ground: Structure, excited: Structure) -> float: """ return np.sqrt( np.sum( - list( - map( - lambda x: x[0].distance(x[1]) ** 2 * x[0].specie.atomic_mass, - zip(ground, excited), - ) - ) - ) + [ + x[0].distance(x[1]) ** 2 * x[0].specie.atomic_mass + for x in zip(ground, excited) + ], + ), ) @@ -716,11 +604,14 @@ def _get_omega( def _fit_parabola( - Q: npt.ArrayLike, energy: npt.ArrayLike, Q0: float, E0: float -) -> Tuple[float, float, float]: + Q: npt.ArrayLike, + energy: npt.ArrayLike, + Q0: float, + E0: float, +) -> tuple[float, float, float]: """Fit the parabola to the data.""" - def f(Q, omega): + def f(Q: float, omega: float) -> float: """Get the parabola function.""" return 0.5 * omega**2 * (Q - Q0) ** 2 + E0 @@ -743,7 +634,7 @@ def _get_wswq_slope(distortions: list[float], wswqs: list[WSWQ]) -> npt.NDArray: yy = np.stack([np.abs(ww.data) * np.sign(qq) for qq, ww in zip(distortions, wswqs)]) _, *oldshape = yy.shape return np.polyfit(distortions, yy.reshape(yy.shape[0], -1), deg=1)[0].reshape( - *oldshape + *oldshape, ) @@ -764,7 +655,7 @@ def _get_ks_ediff( npt.NDArray: The Kohn-Sham energy difference between the defect state and other states. Indexed the same way as ``bandstructure.bands``. """ - res = dict() + res = {} b_at_kpt_and_spin = {(k, s): b for b, k, s in defect_band} for ispin, eigs in bandstructure.bands.items(): spin_index = 0 if ispin == Spin.up else 1 @@ -780,7 +671,11 @@ def _get_ks_ediff( def plot_pes( - hd: HarmonicDefect, x_shift=0, y_shift=0, width: float = 1.0, ax: Axes = None + hd: HarmonicDefect, + x_shift: float = 0, + y_shift: float = 0, + width: float = 1.0, + ax: Axes = None, ) -> None: """Plot the Potential Energy Surface of a HarmonicDefect. diff --git a/pymatgen/analysis/defects/core.py b/pymatgen/analysis/defects/core.py index a6708491..43afaf36 100644 --- a/pymatgen/analysis/defects/core.py +++ b/pymatgen/analysis/defects/core.py @@ -6,7 +6,7 @@ import logging from abc import ABCMeta, abstractmethod, abstractproperty from enum import Enum -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING import numpy as np from monty.json import MSONable @@ -15,6 +15,7 @@ from pymatgen.core import Element, PeriodicSite, Species from pymatgen.core.periodic_table import DummySpecies from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from typing_extensions import Self from .utils import get_plane_spacing @@ -36,6 +37,8 @@ _logger = logging.getLogger(__name__) +RNG = np.random.default_rng(42) + class DefectType(Enum): """Defect type, for sorting purposes.""" @@ -92,7 +95,7 @@ def __init__( # check oxi_states assigned and not all zero if all(specie.oxi_state == 0 for specie in self.structure.species): self.structure.add_oxidation_state_by_guess() - except Exception: + except Exception: # noqa: BLE001 # pragma: no cover self.structure.add_oxidation_state_by_guess() self.oxi_state = self._guess_oxi_state() else: @@ -127,7 +130,7 @@ def defect_structure(self) -> Structure: """Get the unit-cell structure representing the defect.""" @abstractproperty - def element_changes(self) -> Dict[Element, int]: + def element_changes(self) -> dict[Element, int]: """Get the species changes of the defect. Returns: @@ -166,9 +169,11 @@ def get_charge_states(self, padding: int = 1) -> list[int]: else: # pragma: no cover sign = -1 if self.oxi_state < 0 else 1 oxi_state = sign * int(np.ceil(abs(self.oxi_state))) - _logger.warn( + _logger.warning( "Non-integer oxidation state detected." - f"Rounding to integer with larger absolute value: {self.oxi_state} -> {oxi_state}" + "Round to integer with larger absolute value: %s -> %s", + self.oxi_state, + oxi_state, ) if oxi_state >= 0: @@ -232,8 +237,8 @@ def get_supercell_structure( PeriodicSite (optional): The position of the defect site in the supercell. """ - def _has_oxi(struct): - return all([hasattr(site.specie, "oxi_state") for site in struct]) + def _has_oxi(struct: Structure) -> bool: + return all(hasattr(site.specie, "oxi_state") for site in struct) if defect_structure is None: defect_structure = self.centered_defect_structure @@ -278,7 +283,7 @@ def _has_oxi(struct): # interstitials int_uc_indices = set(range(len(defect_structure))) - set( - defect_site_mapping.keys() + defect_site_mapping.keys(), ) for i in int_uc_indices: int_sc_pos = np.dot(defect_structure[i].frac_coords, sc_mat_inv) @@ -321,19 +326,22 @@ def _has_oxi(struct): def symmetrized_structure(self) -> SymmetrizedStructure: """Get the symmetrized version of the bulk structure.""" sga = SpacegroupAnalyzer( - self.structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance + self.structure, + symprec=self.symprec, + angle_tolerance=self.angle_tolerance, ) return sga.get_symmetrized_structure() def __eq__(self, __o: object) -> bool: """Equality operator.""" if not isinstance(__o, Defect): # pragma: no cover - raise TypeError("Can only compare Defects to Defects") + msg = "Can only compare Defects to Defects" + raise TypeError(msg) sm = StructureMatcher(comparator=ElementComparator()) return sm.fit(self.defect_structure, __o.defect_structure) @property - def defect_type(self) -> int: + def defect_type(self) -> DefectType: """Get the defect type. Returns: @@ -380,7 +388,9 @@ def __init__(self, name: str, bulk_formula: str, element_changes: dict) -> None: self.element_changes = element_changes @classmethod - def from_structures(cls, defect_structure: Structure, bulk_structure: Structure): + def from_structures( + cls, defect_structure: Structure, bulk_structure: Structure + ) -> Self: """Initialize a NameDefect object from structures. Args: @@ -412,7 +422,8 @@ def latex_name(self) -> str: def __eq__(self, __value: object) -> bool: """Only need to compare names.""" if not isinstance(__value, NamedDefect): # pragma: no cover - raise TypeError("Can only compare NamedDefects to NamedDefects") + msg = "Can only compare NamedDefects to NamedDefects" + raise TypeError(msg) return self.__repr__() == __value.__repr__() def __repr__(self) -> str: @@ -443,15 +454,16 @@ def name(self) -> str: return f"v_{get_element(self.defect_site.specie)}" @property - def defect_site(self): + def defect_site(self) -> PeriodicSite: """Returns the site in the structure that corresponds to the defect site.""" - res = min( + return min( self.structure.get_sites_in_sphere( - self.site.coords, 0.1, include_index=True + self.site.coords, + 0.1, + include_index=True, ), key=lambda x: x[1], ) - return res @property def defect_site_index(self) -> int: @@ -459,14 +471,14 @@ def defect_site_index(self) -> int: return self.defect_site.index @property - def defect_structure(self): + def defect_structure(self) -> Structure: """Returns the defect structure with the proper oxidation state.""" struct = self.structure.copy() struct.remove_sites([self.defect_site_index]) return struct @property - def element_changes(self) -> Dict[Element, int]: + def element_changes(self) -> dict[Element, int]: """Get the species changes of the vacancy defect. Returns: @@ -547,11 +559,13 @@ def defect_structure(self) -> Structure: return struct @property - def defect_site(self): + def defect_site(self) -> PeriodicSite: """Returns the site in the structure that corresponds to the defect site.""" return min( self.structure.get_sites_in_sphere( - self.site.coords, 0.1, include_index=True + self.site.coords, + 0.1, + include_index=True, ), key=lambda x: x[1], ) @@ -562,7 +576,7 @@ def defect_site_index(self) -> int: return self.defect_site.index @property - def element_changes(self) -> Dict[Element, int]: + def element_changes(self) -> dict[Element, int]: """Get the species changes of the substitution defect. Returns: @@ -593,15 +607,18 @@ def _guess_oxi_state(self) -> float: ] if len(sub_elt_sites_in_struct) == 0: sub_states = self.site.specie.common_oxidation_states - if len(sub_states) == 0: - raise ValueError( + if len(sub_states) == 0: # pragma: no cover + msg = ( f"No common oxidation states found for {self.site.specie}." "Please specify the oxidation state manually." ) + raise ValueError( + msg, + ) sub_oxi = min(sub_states, key=lambda x: abs(x - rm_oxi)) else: sub_oxi = int( - np.mean([site.specie.oxi_state for site in sub_elt_sites_in_struct]) + np.mean([site.specie.oxi_state for site in sub_elt_sites_in_struct]), ) return sub_oxi - rm_oxi @@ -642,13 +659,19 @@ def __init__( **kwargs: Additional kwargs to pass to the Defect constructor. """ super().__init__( - structure, site, multiplicity, oxi_state, equivalent_sites, **kwargs + structure, + site, + multiplicity, + oxi_state, + equivalent_sites, + **kwargs, ) def get_multiplicity(self) -> int: """Determine the multiplicity of the defect site within the structure.""" + msg = "Interstitial multiplicity should be determined by the generator." raise NotImplementedError( - "Interstitial multiplicity should be determined by the generator." + msg, ) @property @@ -665,8 +688,9 @@ def defect_structure(self) -> Structure: inter_states = self.site.specie.icsd_oxidation_states[:2] if len(inter_states) == 0: _logger.warning( - f"No oxidation states found for {self.site.specie.symbol}. " - "in ICSD using `oxidation_states` without frequency ranking." + "No oxidation states found for %s. " + "in ICSD using `oxidation_states` without frequency ranking.", + self.site.specie.symbol, ) inter_states = self.site.specie.oxidation_states inter_oxi = max(inter_states, key=abs) @@ -684,7 +708,7 @@ def defect_site_index(self) -> int: return 0 @property - def element_changes(self) -> Dict[Element, int]: + def element_changes(self) -> dict[Element, int]: """Get the species changes of the intersitial defect. Returns: @@ -748,7 +772,8 @@ def __repr__(self) -> str: def __eq__(self, __o: object) -> bool: """Check if are equal.""" if not isinstance(__o, Defect): - raise TypeError("Can only compare Defects to Defects") + msg = "Can only compare Defects to Defects" + raise TypeError(msg) sm = StructureMatcher(comparator=ElementComparator()) this_structure = self.defect_structure_with_com if isinstance(__o, DefectComplex): @@ -766,12 +791,13 @@ def defect_structure_with_com(self) -> Structure: def get_multiplicity(self) -> int: """Determine the multiplicity of the defect site within the structure.""" + msg = "Not implemented for defect complexes" raise NotImplementedError( - "Not implemented for defect complexes" + msg, ) # pragma: no cover @property - def element_changes(self) -> Dict[Element, int]: + def element_changes(self) -> dict[Element, int]: """Determine the species changes of the complex defect.""" cnt: dict[Element, int] = collections.defaultdict(int) for defect in self.defects: @@ -796,7 +822,9 @@ def defect_structure(self) -> Structure: defect_structure = self.structure.copy() for defect in self.defects: update_structure( - defect_structure, defect.site, defect_type=defect.defect_type + defect_structure, + defect.site, + defect_type=defect.defect_type, ) return defect_structure @@ -807,7 +835,9 @@ def latex_name(self) -> str: return "$+$".join(single_names) -def update_structure(structure, site, defect_type): +def update_structure( + structure: Structure, site: PeriodicSite, defect_type: DefectType +) -> None: """Update the structure with the defect site. Types of operations: @@ -824,11 +854,14 @@ def update_structure(structure, site, defect_type): Structure: The updated structure. """ - def _update(structure, site, rm: bool, replace: bool): + def _update( + structure: Structure, site: PeriodicSite, rm: bool, replace: bool + ) -> None: in_sphere = structure.get_sites_in_sphere(site.coords, 0.1, include_index=True) if len(in_sphere) == 0 and rm: # pragma: no cover - raise ValueError("No site found to remove.") + msg = "No site found to remove." + raise ValueError(msg) if rm or replace: rm_site = min( @@ -855,7 +888,8 @@ def _update(structure, site, rm: bool, replace: bool): elif defect_type == DefectType.Interstitial: _update(structure, site, rm=False, replace=False) else: - raise ValueError("Unknown point defect type.") # pragma: no cover + msg = "Unknown point defect type." + raise ValueError(msg) # pragma: no cover class Adsorbate(Interstitial): @@ -899,8 +933,10 @@ def get_vacancy(structure: Structure, isite: int, **kwargs) -> Vacancy: def _set_selective_dynamics( - structure: Structure, site_pos: ArrayLike, relax_radius: float | str | None -): + structure: Structure, + site_pos: ArrayLike, + relax_radius: float | str | None, +) -> None: """Set the selective dynamics behavior. Allow atoms to move for sites within a given radius of a given site, @@ -916,10 +952,13 @@ def _set_selective_dynamics( if relax_radius == "auto": relax_radius = min(get_plane_spacing(structure.lattice.matrix)) / 2.0 if not isinstance(relax_radius, float): - raise ValueError("relax_radius must be a float or 'auto' or None") + msg = "relax_radius must be a float or 'auto' or None" + raise ValueError(msg) structure.get_sites_in_sphere(site_pos, relax_radius) relax_sites = structure.get_sites_in_sphere( - site_pos, relax_radius, include_index=True + site_pos, + relax_radius, + include_index=True, ) relax_indices = [site.index for site in relax_sites] relax_mask = [[False, False, False]] * len(structure) @@ -929,7 +968,7 @@ def _set_selective_dynamics( def perturb_sites( - structure, + structure: Structure, distance: float, min_distance: float | None = None, site_indices: list | None = None, @@ -952,13 +991,13 @@ def perturb_sites( """ - def get_rand_vec(): + def get_rand_vec() -> ArrayLike: # deals with zero vectors. - vector = np.random.randn(3) + vector = RNG.normal(size=3) vnorm = np.linalg.norm(vector) dist = distance if isinstance(min_distance, (float, int)): - dist = np.random.uniform(min_distance, dist) + dist = RNG.uniform(min_distance, dist) return vector / vnorm * dist if vnorm != 0 else get_rand_vec() if site_indices is None: @@ -970,7 +1009,7 @@ def get_rand_vec(): structure.translate_sites([i], get_rand_vec(), frac_coords=False) -def _perturb_dynamic_sites(structure, distance): +def _perturb_dynamic_sites(structure: Structure, distance: float) -> None: free_indices = [ i for i, site in enumerate(structure) @@ -979,7 +1018,9 @@ def _perturb_dynamic_sites(structure, distance): perturb_sites(structure=structure, distance=distance, site_indices=free_indices) -def _get_mapped_sites(uc_structure: Structure, sc_structure: Structure, r=0.001): +def _get_mapped_sites( + uc_structure: Structure, sc_structure: Structure, r: float = 0.001 +) -> dict: """Get the list of sites indices in the supercell corresponding to the unit cell.""" mapped_site_indices = {} for isite, uc_site in enumerate(uc_structure): @@ -1019,7 +1060,7 @@ def _get_el_changes_from_structures(defect_sc: Structure, bulk_sc: Structure) -> dict: A dictionary representing the species changes in creating the defect. """ - def _check_int(n): + def _check_int(n: float) -> bool: return isinstance(n, int) or n.is_integer() comp_defect = defect_sc.composition.element_composition @@ -1030,8 +1071,9 @@ def _check_int(n): for el, cnt in comp_defect.items(): # has to be integer if not (_check_int(comp_bulk[el]) and _check_int(cnt)): + msg = "Defect structure and bulk structure must have integer compositions." raise ValueError( - "Defect structure and bulk structure must have integer compositions." + msg, ) tmp_ = int(cnt) - int(comp_bulk[el]) if tmp_ != 0: diff --git a/pymatgen/analysis/defects/corrections/freysoldt.py b/pymatgen/analysis/defects/corrections/freysoldt.py index 97389944..4b433ebd 100644 --- a/pymatgen/analysis/defects/corrections/freysoldt.py +++ b/pymatgen/analysis/defects/corrections/freysoldt.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np @@ -22,6 +22,7 @@ from scipy import stats if TYPE_CHECKING: + from matplotlib.axes import Axes from numpy.typing import ArrayLike from pymatgen.core import Lattice @@ -44,11 +45,11 @@ def get_freysoldt_correction( dielectric: float, defect_locpot: Locpot, bulk_locpot: Locpot, - defect_frac_coords: Optional[ArrayLike] = None, - lattice: Optional[Lattice] = None, + defect_frac_coords: ArrayLike | None = None, + lattice: Lattice | None = None, energy_cutoff: float = 520, mad_tol: float = 1e-4, - q_model: Optional[QModel] = None, + q_model: QModel | None = None, step: float = 1e-4, ) -> CorrectionResult: """Gets the Freysoldt correction for a defect entry. @@ -97,9 +98,10 @@ def get_freysoldt_correction( dielectric = float(np.mean(dielectric)) elif np.ndim(dielectric) == 2: # pragma: no cover dielectric = float(np.mean(dielectric.diagonal())) - else: + else: # pragma: no cover + msg = f"Dielectric constant cannot be converted into a scalar. Currently of type {type(dielectric)}" raise ValueError( - f"Dielectric constant is cannot be converted into a scalar. Currently of type {type(dielectric)}" + msg, ) q_model = QModel() if q_model is None else q_model @@ -107,12 +109,13 @@ def get_freysoldt_correction( if isinstance(defect_locpot, VolumetricData): list_axis_grid = [*map(defect_locpot.get_axis_grid, [0, 1, 2])] list_defect_plnr_avg_esp = [ - *map(defect_locpot.get_average_along_axis, [0, 1, 2]) + *map(defect_locpot.get_average_along_axis, [0, 1, 2]), ] lattice_ = defect_locpot.structure.lattice.copy() if lattice is not None and lattice != lattice_: + msg = "Lattice of defect_locpot and user provided lattice do not match." raise ValueError( - "Lattice of defect_locpot and user provided lattice do not match." + msg, ) lattice = lattice_ elif isinstance(defect_locpot, dict): @@ -124,19 +127,20 @@ def get_freysoldt_correction( [0, 0, 0], lattice.abc, [len(i) for i in list_defect_plnr_avg_esp], - ) + ), ] else: - raise ValueError("defect_locpot must be of type Locpot or dict") + msg = "defect_locpot must be of type Locpot or dict" + raise ValueError(msg) - # TODO this can be done with regridding later if isinstance(bulk_locpot, VolumetricData): list_bulk_plnr_avg_esp = [*map(bulk_locpot.get_average_along_axis, [0, 1, 2])] elif isinstance(bulk_locpot, dict): bulk_locpot_ = {int(k): v for k, v in bulk_locpot.items()} list_bulk_plnr_avg_esp = [bulk_locpot_[i] for i in range(3)] - else: - raise ValueError("bulk_locpot must be of type Locpot or dict") + else: # pragma: no cover + msg = "bulk_locpot must be of type Locpot or dict" + raise ValueError(msg) es_corr = perform_es_corr( lattice=lattice, @@ -148,11 +152,14 @@ def get_freysoldt_correction( step=step, ) - alignment_corrs = dict() - plot_data = dict() + alignment_corrs = {} + plot_data = {} for x, pureavg, defavg, axis in zip( - list_axis_grid, list_bulk_plnr_avg_esp, list_defect_plnr_avg_esp, [0, 1, 2] + list_axis_grid, + list_bulk_plnr_avg_esp, + list_defect_plnr_avg_esp, + [0, 1, 2], ): alignment_corr, md = perform_pot_corr( axis_grid=x, @@ -186,7 +193,13 @@ def get_freysoldt_correction( def perform_es_corr( - lattice, q, dielectric, q_model, energy_cutoff=520, mad_tol=1e-4, step=1e-4 + lattice: Lattice, + q: float, + dielectric: float, + q_model: QModel, + energy_cutoff: float = 520, + mad_tol: float = 1e-4, + step: float = 1e-4, ) -> float: """Perform Electrostatic Freysoldt Correction. @@ -206,15 +219,15 @@ def perform_es_corr( Electrostatic Point Charge contribution to Freysoldt Correction (float) """ _logger.info( - "Running Freysoldt 2011 PC calculation (should be equivalent to sxdefectalign)" + "Running Freysoldt 2011 PC calculation (should be equivalent to sxdefectalign)", ) - _logger.debug("defect lattice constants are (in angstroms)" + str(lattice.abc)) + _logger.debug("defect lattice constants are (in angstroms) %s", str(lattice.abc)) [a1, a2, a3] = ang_to_bohr * np.array(lattice.get_cartesian_coords(1)) - logging.debug("In atomic units, lat consts are (in bohr):" + str([a1, a2, a3])) + logging.debug("In atomic units, lat consts are (in bohr): %s", str([a1, a2, a3])) vol = np.dot(a1, np.cross(a2, a3)) # vol in bohr^3 - def e_iso(encut): + def e_iso(encut: float) -> float: gcut = eV_to_k(encut) # gcut is in units of 1/A return ( scipy.integrate.quad(lambda g: q_model.rho_rec(g * g) ** 2, step, gcut)[0] @@ -222,8 +235,8 @@ def e_iso(encut): / np.pi ) - def e_per(encut): - eper = 0 + def e_per(encut: float) -> float: + eper = 0.0 for g2 in generate_reciprocal_vectors_squared(a1, a2, a3, encut): eper += (q_model.rho_rec(g2) ** 2) / g2 eper *= (q**2) * 2 * round(np.pi, 6) / vol @@ -245,18 +258,18 @@ def e_per(encut): def perform_pot_corr( - axis_grid, - pureavg, - defavg, - lattice, - q, - defect_frac_coords, - axis, - dielectric, - q_model, - mad_tol=1e-4, - widthsample=1.0, -): + axis_grid: ArrayLike, + pureavg: ArrayLike, + defavg: ArrayLike, + lattice: Lattice, + q: float, + defect_frac_coords: ArrayLike, + axis: Axes, + dielectric: float, + q_model: QModel, + mad_tol: float = 1e-4, + widthsample: float = 1.0, +) -> tuple[float, dict]: """For performing planar averaging potential alignment. Args: @@ -290,7 +303,7 @@ def perform_pot_corr( (float) Potential Alignment shift required to make the short range potential zero far from the defect. (-C) in the Freysoldt paper. """ - logging.debug("run Freysoldt potential alignment method for axis " + str(axis)) + logging.debug("run Freysoldt potential alignment method for axis %s", str(axis)) nx = len(axis_grid) # shift these planar averages to have defect at origin @@ -327,7 +340,8 @@ def perform_pot_corr( v_R = np.fft.fft(v_G) if abs(np.imag(v_R).max()) > mad_tol: - raise Exception("imaginary part found to be %s", repr(np.imag(v_R).max())) + msg = "imaginary part found to be %s" + raise Exception(msg, repr(np.imag(v_R).max())) v_R /= lattice.volume * ang_to_bohr**3 v_R = np.real(v_R) * hart_to_ev @@ -351,13 +365,13 @@ def perform_pot_corr( C = np.mean(tmppot) _logger.debug("C = %f", C) short_range = short - v_R = [elmnt for elmnt in v_R] + v_R = list(v_R) _logger.info("C value is averaged to be %f eV ", C) _logger.info("Potentital alignment energy correction (q * Delta): %f (eV)", q * C) # log plotting data: - metadata = dict() + metadata = {} metadata["pot_plot_data"] = { "Vr": v_R, "x": axis_grid, @@ -374,7 +388,9 @@ def perform_pot_corr( return C, metadata -def plot_plnr_avg(plot_data, title=None, saved=False, ax=None): +def plot_plnr_avg( + plot_data: dict, title: str | None = None, saved: bool = False, ax: Axes = None +) -> Axes: """Plot the planar average electrostatic potential. Plot the planar average electrostatic potential against the Long range and @@ -388,8 +404,9 @@ def plot_plnr_avg(plot_data, title=None, saved=False, ax=None): If True then saves plot as str(title) + "FreyplnravgPlot.pdf" ax (matplotlib.axes.Axes): Axes object to plot on. If None, makes new figure. """ - if not plot_data["pot_plot_data"]: - raise ValueError("Cannot plot potential alignment before running correction!") + if not plot_data["pot_plot_data"]: # pragma: no cover + msg = "Cannot plot potential alignment before running correction!" + raise ValueError(msg) x = plot_data["pot_plot_data"]["x"] v_R = plot_data["pot_plot_data"]["Vr"] @@ -408,7 +425,12 @@ def plot_plnr_avg(plot_data, title=None, saved=False, ax=None): tmpx = [x[i] for i in range(check[0], check[1])] ax.fill_between( - tmpx, -100, 100, facecolor="red", alpha=0.15, label="sampling region" + tmpx, + -100, + 100, + facecolor="red", + alpha=0.15, + label="sampling region", ) ax.set_xlim(round(x[0]), round(x[-1])) @@ -419,9 +441,9 @@ def plot_plnr_avg(plot_data, title=None, saved=False, ax=None): ax.set_ylabel("Potential (V)", fontsize=15) ax.legend(loc=9) ax.axhline(y=0, linewidth=0.2, color="black") - if title is not None: + if title is not None: # pragma: no cover ax.set_title(str(title), fontsize=18) ax.set_xlim(0, max(x)) - if saved: + if saved: # pragma: no cover fig.savefig(str(title) + "FreyplnravgPlot.pdf") return ax diff --git a/pymatgen/analysis/defects/corrections/kumagai.py b/pymatgen/analysis/defects/corrections/kumagai.py index acd48a05..18d99130 100644 --- a/pymatgen/analysis/defects/corrections/kumagai.py +++ b/pymatgen/analysis/defects/corrections/kumagai.py @@ -38,11 +38,12 @@ logging.getLogger("pydefect").setLevel(logging.WARNING) -def _check_import_pydefect(): +def _check_import_pydefect() -> None: """Import pydefect if it is installed.""" if __has_pydefect__: + msg = "vise/pydefect is not installed. Please install it first." raise ModuleNotFoundError( - "vise/pydefect is not installed. Please install it first." + msg, ) @@ -71,8 +72,7 @@ def get_structure_with_pot(directory: Path) -> Structure: ionic_conv=vasprun.converged_ionic, ) - struct = calc.structure.copy(site_properties={"potential": calc.potentials}) - return struct + return calc.structure.copy(site_properties={"potential": calc.potentials}) def get_efnv_correction( @@ -118,5 +118,6 @@ def get_efnv_correction( ) return CorrectionResult( - correction_energy=efnv_corr.correction_energy, metadata={"efnv_corr": efnv_corr} + correction_energy=efnv_corr.correction_energy, + metadata={"efnv_corr": efnv_corr}, ) diff --git a/pymatgen/analysis/defects/finder.py b/pymatgen/analysis/defects/finder.py index 92173e1d..709402a4 100644 --- a/pymatgen/analysis/defects/finder.py +++ b/pymatgen/analysis/defects/finder.py @@ -3,9 +3,7 @@ from __future__ import annotations import logging -import warnings -from collections import namedtuple -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, NamedTuple import numpy as np from monty.json import MSONable @@ -31,14 +29,27 @@ _logger = logging.getLogger(__name__) DUMMY_SPECIES = "Si" -SiteVec = namedtuple("SiteVec", ["species", "site", "vec"]) -SiteGroup = namedtuple("SiteGroup", ["species", "similar_sites", "vec"]) + +class SiteVec(NamedTuple): + """NamedTuple representing a site in the defect structure.""" + + species: str + site: Structure + vec: NDArray + + +class SiteGroup(NamedTuple): + """NamedTuple representing a group of symmetrically equivalent sites.""" + + species: str + similar_sites: list[int] + vec: NDArray class DefectSiteFinder(MSONable): """Find the location of a defect with no pior knowledge.""" - def __init__(self, symprec: float = 0.01, angle_tolerance: float = 5.0): + def __init__(self, symprec: float = 0.01, angle_tolerance: float = 5.0) -> None: """Configure the behavior of the defect site finder. Args: @@ -46,18 +57,19 @@ def __init__(self, symprec: float = 0.01, angle_tolerance: float = 5.0): angle_tolerance (float): Angle tolerance parameter for SpacegroupAnalyzer """ if SOAP is None: + msg = "dscribe is required to use DefectSiteFinder. Install with ``pip install dscribe``." raise ImportError( - "dscribe is required to use DefectSiteFinder. Install with ``pip install dscribe``." + msg, ) self.symprec = symprec self.angle_tolerance = angle_tolerance def get_defect_fpos( self, - defect_structure: "Structure", - base_structure: "Structure", + defect_structure: Structure, + base_structure: Structure, remove_oxi: bool = True, - ) -> "ArrayLike": + ) -> ArrayLike: """Get the position of a defect in the pristine structure. Args: @@ -75,11 +87,12 @@ def get_defect_fpos( if self._is_impurity(defect_structure, base_structure): return self.get_impurity_position(defect_structure, base_structure) - else: - return self.get_native_defect_position(defect_structure, base_structure) + return self.get_native_defect_position(defect_structure, base_structure) def _is_impurity( - self, defect_structure: "Structure", base_structure: "Structure" + self, + defect_structure: Structure, + base_structure: Structure, ) -> bool: """Check if the defect structure is an impurity. @@ -96,8 +109,10 @@ def _is_impurity( return len(defect_species - base_species) > 0 def get_native_defect_position( - self, defect_structure: "Structure", base_structure: "Structure" - ) -> "ArrayLike": + self, + defect_structure: Structure, + base_structure: Structure, + ) -> ArrayLike: """Get the position of a native defect in the defect structure. Args: @@ -109,16 +124,20 @@ def get_native_defect_position( (in fractional coordinates) """ distored_sites, distortions = list( - zip(*self.get_most_distorted_sites(defect_structure, base_structure)) + zip(*self.get_most_distorted_sites(defect_structure, base_structure)), ) positions = [defect_structure[isite].frac_coords for isite in distored_sites] return get_weighted_average_position( - defect_structure.lattice, positions, distortions + defect_structure.lattice, + positions, + distortions, ) def get_impurity_position( - self, defect_structure: "Structure", base_structure: "Structure" - ): + self, + defect_structure: Structure, + base_structure: Structure, + ) -> ArrayLike: """Get the position of an impurity defect. Look at all sites with impurity atoms, and take the average of the positions of @@ -134,15 +153,18 @@ def get_impurity_position( # get the pbc average position of all sites not in the base structure base_species = {site.species_string for site in base_structure} impurity_sites = [ - *filter(lambda x: x.species_string not in base_species, defect_structure) + *filter(lambda x: x.species_string not in base_species, defect_structure), ] return get_weighted_average_position( - defect_structure.lattice, [s.frac_coords for s in impurity_sites] + defect_structure.lattice, + [s.frac_coords for s in impurity_sites], ) def get_most_distorted_sites( - self, defect_structure: "Structure", base_structure: "Structure" - ) -> List[Tuple[int, float]]: + self, + defect_structure: Structure, + base_structure: Structure, + ) -> list[tuple[int, float]]: """Identify the set of sites with the most deviation from the pristine. Performs the following steps: @@ -175,9 +197,9 @@ def get_most_distorted_sites( best_s, ) = best_match(v, pristine_groups) if v.species != best_m.species: - warnings.warn( + _logger.warning( "The species of a site in the distorted structure is different " - "from the species of the closest pristine site." + "from the species of the closest pristine site.", ) res.append((i, np.abs(best_s - 1))) @@ -189,7 +211,9 @@ def get_most_distorted_sites( # %% -def get_site_groups(struct, symprec=0.01, angle_tolerance=5.0) -> List[SiteGroup]: +def get_site_groups( + struct: Structure, symprec: float = 0.01, angle_tolerance: float = 5.0 +) -> list[SiteGroup]: """Group the sites in the structure by symmetry. Group the sites in the structure by symmetry and return a @@ -212,13 +236,15 @@ def get_site_groups(struct, symprec=0.01, angle_tolerance=5.0) -> List[SiteGroup soap_vec = get_soap_vec(struct) for g in groups: sg = SiteGroup( - species=sstruct[g[0]].species_string, similar_sites=g, vec=soap_vec[g[0]] + species=sstruct[g[0]].species_string, + similar_sites=g, + vec=soap_vec[g[0]], ) site_groups.append(sg) return site_groups -def get_soap_vec(struct: "Structure") -> NDArray: +def get_soap_vec(struct: Structure) -> NDArray: """Get the SOAP vector for each site in the structure. Args: @@ -234,11 +260,10 @@ def get_soap_vec(struct: "Structure") -> NDArray: for el in species_: dummy_structure.replace_species({str(el): DUMMY_SPECIES}) soap_desc = SOAP(species=[DUMMY_SPECIES], r_cut=5, n_max=8, l_max=6, periodic=True) - vecs = soap_desc.create(adaptor.get_atoms(dummy_structure)) - return vecs + return soap_desc.create(adaptor.get_atoms(dummy_structure)) -def get_site_vecs(struct: Structure) -> List[SiteVec]: +def get_site_vecs(struct: Structure) -> list[SiteVec]: """Get the SiteVec representation of each site in the structure. Args: @@ -254,7 +279,7 @@ def get_site_vecs(struct: Structure) -> List[SiteVec]: ] -def cosine_similarity(vec1, vec2) -> float: +def cosine_similarity(vec1: ArrayLike, vec2: ArrayLike) -> float: """Cosine similarity between two vectors. Args: @@ -267,7 +292,7 @@ def cosine_similarity(vec1, vec2) -> float: return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) -def best_match(sv: SiteVec, sgs: List[SiteGroup]) -> Tuple[SiteGroup, float]: +def best_match(sv: SiteVec, sgs: list[SiteGroup]) -> tuple[SiteGroup, float]: """Find the best match for a site in the defect structure. Args: @@ -290,11 +315,12 @@ def best_match(sv: SiteVec, sgs: List[SiteGroup]) -> Tuple[SiteGroup, float]: best_similarity = csim best_match = sg if best_match is None: - raise ValueError("No matching species found.") + msg = "No matching species found." + raise ValueError(msg) return best_match, best_similarity -def _get_broundary(arr, n_max=16, n_skip=3) -> int: +def _get_broundary(arr: list, n_max: int = 16, n_skip: int = 3) -> int: """Get the boundary index for the high-distortion indices. Assuming arr is sorted in reverse order, @@ -314,7 +340,9 @@ def _get_broundary(arr, n_max=16, n_skip=3) -> int: def get_weighted_average_position( - lattice: Lattice, frac_positions: ArrayLike, weights: ArrayLike | None = None + lattice: Lattice, + frac_positions: ArrayLike, + weights: ArrayLike | None = None, ) -> NDArray: """Get the weighted average position of a set of positions in frac coordinates. @@ -338,7 +366,8 @@ def get_weighted_average_position( if weights is None: weights = [1.0] * len(frac_positions) if len(frac_positions) != len(weights): - raise ValueError("The number of positions and weights must be the same.") + msg = "The number of positions and weights must be the same." + raise ValueError(msg) # TODO: can be replaced with the zip(..., strict=True) syntax in Python 3.10 pos_weights = list(zip(frac_positions, weights)) diff --git a/pymatgen/analysis/defects/generators.py b/pymatgen/analysis/defects/generators.py index 142c7160..08bb8f17 100644 --- a/pymatgen/analysis/defects/generators.py +++ b/pymatgen/analysis/defects/generators.py @@ -7,7 +7,7 @@ import logging from abc import ABCMeta from itertools import combinations -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING from monty.json import MSONable from pymatgen.analysis.defects.core import Interstitial, Substitution, Vacancy @@ -21,7 +21,7 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer if TYPE_CHECKING: - from typing import Sequence + from collections.abc import Generator, Sequence from pymatgen.analysis.defects.core import Defect from pymatgen.io.vasp import VolumetricData @@ -46,12 +46,12 @@ def _space_group_analyzer(self, structure: Structure) -> SpacegroupAnalyzer: symprec=self.symprec, angle_tolerance=self.angle_tolerance, ) - else: # pragma: no cover - raise ValueError( - "This generator is using the `SpaceGroupAnalyzer` and requires `symprec` and `angle_tolerance` to be set." - ) + msg = "This generator is using the `SpaceGroupAnalyzer` and requires `symprec` and `angle_tolerance` to be set." + raise ValueError( + msg, + ) - def generate(self, *args, **kwargs) -> Generator[Defect, None, None]: + def generate(self, *args, **kwargs) -> Generator[Defect, None, None]: # noqa: ANN002 """Generate a defect. Args: @@ -63,7 +63,7 @@ def generate(self, *args, **kwargs) -> Generator[Defect, None, None]: """ raise NotImplementedError - def get_defects(self, *args, **kwargs) -> list[Defect]: + def get_defects(self, *args, **kwargs) -> list[Defect]: # noqa: ANN002 """Alias for self.generate.""" return list(self.generate(*args, **kwargs)) @@ -82,7 +82,7 @@ def __init__( self, symprec: float = 0.01, angle_tolerance: float = 5, - ): + ) -> None: """Initialize the vacancy generator.""" self.symprec = symprec self.angle_tolerance = angle_tolerance @@ -104,14 +104,12 @@ def generate( Generator[Vacancy, None, None]: Generator that yields a list of ``Vacancy`` objects. """ all_species = [*map(_element_str, structure.composition.elements)] - if rm_species is None: - rm_species = all_species - else: - rm_species = [*map(str, rm_species)] + rm_species = all_species if rm_species is None else [*map(str, rm_species)] if not set(rm_species).issubset(all_species): + msg = f"rm_species({rm_species}) must be a subset of the structure's species ({all_species})." raise ValueError( - f"rm_species({rm_species}) must be a subset of the structure's species ({all_species})." + msg, ) sga = self._space_group_analyzer(structure) @@ -136,13 +134,16 @@ class SubstitutionGenerator(DefectGenerator): """ - def __init__(self, symprec: float = 0.01, angle_tolerance: float = 5): + def __init__(self, symprec: float = 0.01, angle_tolerance: float = 5) -> None: """Initialize the substitution generator.""" self.symprec = symprec self.angle_tolerance = angle_tolerance def generate( - self, structure: Structure, substitution: dict[str, str | list], **kwargs + self, + structure: Structure, + substitution: dict[str, str | list], + **kwargs, ) -> Generator[Substitution, None, None]: """Generate subsitutional defects. @@ -161,7 +162,7 @@ def generate( for site_group in sym_struct.equivalent_sites: site = site_group[0] el_str = _element_str(site.specie) - if el_str not in substitution.keys(): + if el_str not in substitution: continue sub_el = substitution[el_str] if isinstance(sub_el, str): @@ -217,7 +218,7 @@ class AntiSiteGenerator(DefectGenerator): angle_tolerance: Angle tolerance for symmetry finding (parameter for ``SpacegroupAnalyzer``). """ - def __init__(self, symprec: float = 0.01, angle_tolerance: float = 5): + def __init__(self, symprec: float = 0.01, angle_tolerance: float = 5) -> None: """Initialize the anti-site generator.""" self.symprec = symprec self.angle_tolerance = angle_tolerance @@ -239,7 +240,7 @@ def generate( for u, v in combinations(all_species, 2): subs[u].append(v) subs[v].append(u) - _logger.debug(f"All anti-site pairings: {subs}") + _logger.debug("All anti-site pairings: %s", subs) for site, species in subs.items(): for sub in species: yield from self._sub_gen.generate(structure, {site: sub}, **kwargs) @@ -287,7 +288,7 @@ def generate( el_str: [ [insertions[el_str][i]] for i in range(len(insertions[el_str])) ] - for el_str in insertions.keys() + for el_str in insertions } for el_str, coords in insertions.items(): @@ -295,11 +296,15 @@ def generate( mul = multiplicities[el_str][i] equiv_positions = equivalent_positions[el_str][i] isite = PeriodicSite( - species=Species(el_str), coords=coord, lattice=structure.lattice + species=Species(el_str), + coords=coord, + lattice=structure.lattice, ) equiv_sites = [ PeriodicSite( - species=Species(el_str), coords=coord, lattice=structure.lattice + species=Species(el_str), + coords=coord, + lattice=structure.lattice, ) for coord in equiv_positions ] @@ -312,7 +317,9 @@ def generate( ) def _filter_colliding( - self, fcoords: Sequence[Sequence[float]], structure: Structure + self, + fcoords: Sequence[Sequence[float]], + structure: Structure, ) -> Generator[tuple[int, Sequence[float]], None, None]: """Check the sites for collisions. @@ -320,11 +327,13 @@ def _filter_colliding( fcoords: List of fractional coordinates of the sites. structure: The bulk structure the interstitials placed in. """ - unique_fcoords = set(tuple(f) for f in fcoords) + unique_fcoords = {tuple(f) for f in fcoords} cleaned_fcoords = remove_collisions( - fcoords=list(unique_fcoords), structure=structure, min_dist=self.min_dist + fcoords=list(unique_fcoords), + structure=structure, + min_dist=self.min_dist, ) - cleaned_fcoords = set(tuple(f) for f in cleaned_fcoords) + cleaned_fcoords = {tuple(f) for f in cleaned_fcoords} for i, fc in enumerate(fcoords): if tuple(fc) not in cleaned_fcoords: continue @@ -375,7 +384,8 @@ def generate( # type: ignore[override] **kwargs: Additional keyword arguments for the ``Interstitial`` constructor. """ if len(set(insert_species)) != len(insert_species): # pragma: no cover - raise ValueError("Insert species must be unique.") + msg = "Insert species must be unique." + raise ValueError(msg) cand_sites_mul_and_equiv_fpos = [*self._get_candidate_sites(structure)] for species in insert_species: cand_sites, multiplicity, equiv_fpos = zip(*cand_sites_mul_and_equiv_fpos) @@ -389,7 +399,8 @@ def generate( # type: ignore[override] ) def _get_candidate_sites( - self, structure: Structure + self, + structure: Structure, ) -> Generator[tuple[list[float], int, list[list[float]]], None, None]: """Get the candidate sites for interstitials. @@ -409,9 +420,9 @@ def _get_candidate_sites( angle_tol=self.angle_tol, **self.top_kwargs, ) - insert_sites = dict() - multiplicity: dict[int, int] = dict() - equiv_fpos: dict[int, list[list[float]]] = dict() + insert_sites = {} + multiplicity: dict[int, int] = {} + equiv_fpos: dict[int, list[list[float]]] = {} for fpos, lab in top.labeled_sites: if lab in insert_sites: multiplicity[lab] += 1 @@ -421,7 +432,7 @@ def _get_candidate_sites( multiplicity[lab] = 1 equiv_fpos[lab] = [fpos] - for key in insert_sites.keys(): + for key in insert_sites: yield insert_sites[key], multiplicity[key], equiv_fpos[key] @@ -462,7 +473,10 @@ def __init__( super().__init__(min_dist=min_dist) def generate( # type: ignore[override] - self, chgcar: VolumetricData, insert_species: set[str] | list[str], **kwargs + self, + chgcar: VolumetricData, + insert_species: set[str] | list[str], + **kwargs, ) -> Generator[Interstitial, None, None]: """Generate interstitials. @@ -472,7 +486,8 @@ def generate( # type: ignore[override] **kwargs: Additional keyword arguments for the ``Interstitial`` constructor. """ if len(set(insert_species)) != len(insert_species): # pragma: no cover - raise ValueError("Insert species must be unique.") + msg = "Insert species must be unique." + raise ValueError(msg) cand_sites_mul_and_equiv_fpos = [*self._get_candidate_sites(chgcar)] for species in insert_species: cand_sites, multiplicity, equiv_fpos = zip(*cand_sites_mul_and_equiv_fpos) @@ -489,7 +504,7 @@ def generate( # type: ignore[override] **kwargs, ) - def _get_candidate_sites(self, chgcar: Chgcar): + def _get_candidate_sites(self, chgcar: Chgcar) -> Generator[tuple, None, None]: cia = ChargeInsertionAnalyzer( chgcar, clustering_tol=self.clustering_tol, @@ -499,7 +514,8 @@ def _get_candidate_sites(self, chgcar: Chgcar): min_dist=self.min_dist, ) avg_chg_groups = cia.filter_and_group( - avg_radius=self.avg_radius, max_avg_charge=self.max_avg_charge + avg_radius=self.avg_radius, + max_avg_charge=self.max_avg_charge, ) for _, g in avg_chg_groups: yield min(g), len(g), g @@ -511,7 +527,7 @@ def generate_all_native_defects( vac_generator: VacancyGenerator | None = None, int_generator: ChargeInterstitialGenerator | None = None, max_insertions: int | None = None, -): +) -> Generator[Defect, None, None]: """Generate all native defects. Convenience function to generate all native defects for a host structure or chgcar object. @@ -531,7 +547,8 @@ def generate_all_native_defects( struct = host chgcar = None else: - raise ValueError("Host must be a Structure or Chgcar object.") + msg = "Host must be a Structure or Chgcar object." + raise ValueError(msg) species = set(map(_element_str, struct.species)) sub_generator = sub_generator or SubstitutionGenerator() @@ -545,7 +562,7 @@ def generate_all_native_defects( # generate interstitials if a chgcar is provided if chgcar is not None: int_generator = int_generator or ChargeInterstitialGenerator( - max_insertions=max_insertions + max_insertions=max_insertions, ) yield from int_generator.generate(chgcar, insert_species=species) @@ -554,10 +571,10 @@ def _element_str(sp_or_el: Species | Element) -> str: """Convert a species or element to a string.""" if isinstance(sp_or_el, Species): return str(sp_or_el.element) - elif isinstance(sp_or_el, Element): + if isinstance(sp_or_el, Element): return str(sp_or_el) - else: - raise ValueError(f"{sp_or_el} is not a species or element") # pragma: no cover + msg = f"{sp_or_el} is not a species or element" + raise ValueError(msg) def _remove_oxidation_states(structure: Structure) -> Structure: diff --git a/pymatgen/analysis/defects/plotting/optics.py b/pymatgen/analysis/defects/plotting/optics.py index c82660e3..d7f0047a 100644 --- a/pymatgen/analysis/defects/plotting/optics.py +++ b/pymatgen/analysis/defects/plotting/optics.py @@ -9,10 +9,13 @@ import numpy as np import pandas as pd from matplotlib import pyplot as plt -from matplotlib.colors import Normalize +from matplotlib.colors import Colormap, Normalize from pymatgen.electronic_structure.core import Spin if TYPE_CHECKING: + from collections.abc import Sequence + + from matplotlib.axes import Axes from pymatgen.analysis.defects.ccd import HarmonicDefect __author__ = "Jimmy Shen" @@ -27,16 +30,16 @@ def plot_optical_transitions( defect: HarmonicDefect, kpt_index: int = 0, band_window: int = 5, - user_defect_band: tuple = tuple(), + user_defect_band: tuple = (), other_defect_bands: list[int] | None = None, - ijdirs: list[tuple] | None = None, + ijdirs: Sequence[tuple] | None = None, shift_eig: dict[tuple, float] | None = None, x0: float = 0, x_width: float = 2, - ax=None, - cmap=None, - norm=None, -): + ax: Axes = None, + cmap: Colormap = None, + norm: Normalize = None, +) -> tuple[pd.DataFrame, Colormap, Normalize]: """Plot the optical transitions from the defect state to all other states. Only plot the transitions for a specific kpoint index. The arrows present the transitions @@ -78,6 +81,18 @@ def plot_optical_transitions( norm: The matplotlib normalization to use for the color map of the arrows. + Returns: + A pandas dataframe with the following columns: + ib: The band index of the state the arrow is pointing to. + jb: The band index of the defect state. + kpt: The kpoint index of the state the arrow is pointing to. + spin: The spin index of the state the arrow is pointing to. + eig: The eigenvalue of the state the arrow is pointing to. + M.E.: The matrix element of the transition. + cmap: + The matplotlib color map used. + norm: + The matplotlib normalization used. """ d_eigs = get_bs_eigenvalues( defect=defect, @@ -91,16 +106,17 @@ def plot_optical_transitions( defect_band_index = user_defect_band[0] else: defect_band_index = next( - filter(lambda x: x[1] == kpt_index, defect.defect_band) + filter(lambda x: x[1] == kpt_index, defect.defect_band), )[0] - if ax is None: - ax_ = plt.gca() - else: # pragma: no cover - ax_ = ax + ax_ = plt.gca() if ax is None else ax _plot_eigs( - d_eigs, defect.relaxed_bandstructure.efermi, ax=ax_, x0=x0, x_width=x_width + d_eigs, + defect.relaxed_bandstructure.efermi, + ax=ax_, + x0=x0, + x_width=x_width, ) - ijdirs = ijdirs or [(0, 0), (1, 1), (2, 2)] + ijdirs = ijdirs or ((0, 0), (1, 1), (2, 2)) me_plot_data, cmap, norm = _plot_matrix_elements( defect.waveder.cder, d_eigs, @@ -151,7 +167,8 @@ def get_bs_eigenvalues( Dictionary of the format: (iband, ikpt, ispin) -> eigenvalue """ if defect.relaxed_bandstructure is None: # pragma: no cover - raise ValueError("The defect object does not have a band structure.") + msg = "The defect object does not have a band structure." + raise ValueError(msg) other_defect_bands = other_defect_bands or [] @@ -162,7 +179,7 @@ def get_bs_eigenvalues( band_index, kpt_index, spin_index = def_indices spin_key = Spin.up if spin_index == 0 else Spin.down - output: dict[tuple, float] = dict() + output: dict[tuple, float] = {} shift_dict: dict = collections.defaultdict(lambda: 0.0) if shift_eig is not None: shift_dict.update(shift_eig) @@ -178,8 +195,8 @@ def get_bs_eigenvalues( def _plot_eigs( d_eigs: dict[tuple, float], - e_fermi=None, - ax=None, + e_fermi: float | None = None, + ax: Axes = None, x0: float = 0.0, x_width: float = 0.3, **kwargs, @@ -215,26 +232,34 @@ def _plot_eigs( eigs_ = eigenvalues[eigenvalues <= e_fermi] ax.hlines( - eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[0], **kwargs + eigs_, + x0 - (x_width / 2.0), + x0 + (x_width / 2.0), + color=colors[0], + **kwargs, ) eigs_ = eigenvalues[eigenvalues > e_fermi] ax.hlines( - eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[1], **kwargs + eigs_, + x0 - (x_width / 2.0), + x0 + (x_width / 2.0), + color=colors[1], + **kwargs, ) def _plot_matrix_elements( - cder, - d_eig, - defect_band_index, - ijdirs=((0, 0), (1, 1), (2, 2)), - ax=None, - x0=0, - x_width=0.6, - arrow_width=0.1, - cmap=None, - norm=None, -) -> tuple[list[tuple], plt.cm, plt.Normalize]: + cder: dict[tuple, float], + d_eig: dict[tuple, float], + defect_band_index: int, + ijdirs: Sequence[tuple] = ((0, 0), (1, 1), (2, 2)), + ax: Axes = None, + x0: float = 0.0, + x_width: float = 0.6, + arrow_width: float = 0.1, + cmap: Colormap = None, + norm: Normalize = None, +) -> tuple[list[tuple], Colormap, Normalize]: """Plot arrow for the transition from the defect state to all other states. Args: @@ -281,7 +306,7 @@ def _plot_matrix_elements( for idir, jdir in ijdirs: A += np.abs( cder[ib, jb, ik, ispin, idir] - * np.conjugate(cder[ib, jb, ik, ispin, jdir]) + * np.conjugate(cder[ib, jb, ik, ispin, jdir]), ) plot_data.append((jb, ib, eig, A)) @@ -299,7 +324,7 @@ def _plot_matrix_elements( n_arrows = len(plot_data) x_step = x_width / n_arrows x = x0 - x_width / 2 + x_step / 2 - for ib, jb, eig, A in plot_data: + for _ib, _jb, eig, A in plot_data: ax.arrow( x=x, y=y0, @@ -336,10 +361,10 @@ def _get_dataframe(d_eigs: dict, me_plot_data: list[tuple]) -> pd.DataFrame: M.E.: The matrix element of the transition. """ _, ikpt, ispin = next(iter(d_eigs.keys())) - df = pd.DataFrame( + output_dataframe = pd.DataFrame( me_plot_data, columns=["ib", "jb", "eig", "M.E."], ) - df["kpt"] = ikpt - df["spin"] = ispin - return df + output_dataframe["kpt"] = ikpt + output_dataframe["spin"] = ispin + return output_dataframe diff --git a/pymatgen/analysis/defects/plotting/phases.py b/pymatgen/analysis/defects/plotting/phases.py index 7aacb6c8..99d88b0b 100644 --- a/pymatgen/analysis/defects/plotting/phases.py +++ b/pymatgen/analysis/defects/plotting/phases.py @@ -21,9 +21,8 @@ from labellines import labelLines except ImportError: - def labelLines(*args, **kwargs): + def labelLines(*args, **kwargs) -> None: # noqa: ARG001, ANN002 """Dummy function if labellines is not installed.""" - pass __author__ = "Jimmy Shen" @@ -43,7 +42,7 @@ def plot_chempot_2d( label_lines: bool = False, x_vals: list[float] | None = None, label_fontsize: int = 12, -): +) -> None: """Plot the chemical potential diagram for two elements. Args: @@ -104,7 +103,7 @@ def _convex_hull_2d( x_element: Element, y_element: Element, competing_phases: list | None = None, -) -> list[dict]: +) -> list: """Compute the convex hull of a set of points in 2D. Args: @@ -128,9 +127,8 @@ def _convex_hull_2d( xy_points = [(pt[x_element], pt[y_element]) for pt in points] hull = ConvexHull(xy_points) xy_hull = [xy_points[i] for i in hull.vertices] - pt_and_phase = [] - def _get_line_data(i1, i2): + def _get_line_data(i1: int, i2: int) -> tuple: cp1 = competing_phases[hull.vertices[i1]] cp2 = competing_phases[hull.vertices[i2]] shared_keys = cp1.keys() & cp2.keys() @@ -138,7 +136,8 @@ def _get_line_data(i1, i2): return xy_hull[i1], xy_hull[i2], shared_phase # return all pairs of points: - for itr in range(1, len(hull.vertices)): - pt_and_phase.append(_get_line_data(itr - 1, itr)) + pt_and_phase = [ + _get_line_data(itr - 1, itr) for itr in range(1, len(hull.vertices)) + ] pt_and_phase.append(_get_line_data(len(hull.vertices) - 1, 0)) return pt_and_phase diff --git a/pymatgen/analysis/defects/recombination.py b/pymatgen/analysis/defects/recombination.py index 2cfa0fe5..8764338a 100644 --- a/pymatgen/analysis/defects/recombination.py +++ b/pymatgen/analysis/defects/recombination.py @@ -7,7 +7,7 @@ import itertools import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import numpy as np from scipy.interpolate import PchipInterpolator @@ -21,7 +21,7 @@ except ImportError: _logger.warning("Numba not installed. Install Numba for better performance.") - def njit(*args, **kwargs): + def njit(*args, **kwargs) -> Callable: # noqa: ARG001, ANN002 """Dummy decorator for njit.""" return lambda func: func @@ -52,7 +52,7 @@ def fact(n: int) -> float: # pragma: no cover """ if n > 20: return LOOKUP_TABLE[-1] * np.prod( - np.array(list(range(21, n + 1)), dtype=np.double) + np.array(list(range(21, n + 1)), dtype=np.double), ) return LOOKUP_TABLE[n] @@ -85,7 +85,11 @@ def herm(x: float, n: int) -> float: # pragma: no cover @njit(cache=True) def analytic_overlap_NM( - dQ: float, omega1: float, omega2: float, n1: int, n2: int + dQ: float, + omega1: float, + omega2: float, + n1: int, + n2: int, ) -> float: # pragma: no cover """Compute the overlap between two displaced harmonic oscillators. @@ -165,7 +169,7 @@ def get_mQn( m_init: int, Nf: int, ovl: npt.NDArray, -): +) -> tuple[npt.ArrayLike, npt.ArrayLike]: """Get the matrix element values for the position operator. @@ -205,7 +209,7 @@ def get_mn( m_init: int, en_final: float, en_pad: float = 0.5, -): +) -> tuple[npt.ArrayLike, npt.ArrayLike]: """Get the matrix element values for the position operator. @@ -232,7 +236,11 @@ def get_mn( matels = np.zeros_like(E) for n in range(n_min, n_max): matels[n - n_min] = analytic_overlap_NM( - dQ=dQ, omega1=omega_i, omega2=omega_f, n1=m_init, n2=n + dQ=dQ, + omega1=omega_i, + omega2=omega_f, + n1=m_init, + n2=n, ) return E, matels @@ -243,7 +251,7 @@ def pchip_eval( y_coarse: npt.ArrayLike, pad_frac: float = 0.2, n_points: int = 5000, -): +) -> npt.ArrayLike: """Evaluate a piecewise cubic Hermite interpolant. Assuming a function is evenly sampleded on (``x_coarse``, ``y_coarse``), @@ -318,10 +326,19 @@ def get_SRH_coef( rate = np.zeros_like(T, dtype=np.longdouble) for m in range(Ni): E, me = get_mQn( - dQ=dQ, omega_i=omega_i, omega_f=omega_f, m_init=m, Nf=Nf, ovl=ovl + dQ=dQ, + omega_i=omega_i, + omega_f=omega_f, + m_init=m, + Nf=Nf, + ovl=ovl, ) interp_me = pchip_eval( - dE, E, np.abs(np.conj(me) * me), pad_frac=0.2, n_points=5000 + dE, + E, + np.abs(np.conj(me) * me), + pad_frac=0.2, + n_points=5000, ) rate += weights[m, :] * interp_me return 2 * np.pi * g * elph_me**2 * volume * rate diff --git a/pymatgen/analysis/defects/supercells.py b/pymatgen/analysis/defects/supercells.py index d6d18ca9..aaa84203 100644 --- a/pymatgen/analysis/defects/supercells.py +++ b/pymatgen/analysis/defects/supercells.py @@ -10,17 +10,14 @@ from monty.dev import deprecated from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher from pymatgen.core import Lattice -from pymatgen.util.coord_cython import is_coord_subset_pbc, pbc_shortest_vectors -from pyrho.charge_density import ChargeDensity +from pymatgen.util.coord_cython import pbc_shortest_vectors # from ase.build import find_optimal_cell_shape, get_deviation_from_optimal_cell_shape # from pymatgen.io.ase import AseAtomsAdaptor if TYPE_CHECKING: - import numpy.typing as npt from numpy.typing import ArrayLike, NDArray from pymatgen.core import Structure - from pymatgen.io.vasp.outputs import VolumetricData __author__ = "Jimmy-Xuan Shen" __copyright__ = "Copyright 2022, The Materials Project" @@ -55,19 +52,20 @@ def get_sc_fromstruct( Returns: struc_sc: Supercell that is as close to cubic as possible """ - sc_mat = _cubic_cell( + return _cubic_cell( base_struct, min_atoms, max_atoms=max_atoms, min_length=min_length, force_diagonal=force_diagonal, ) - return sc_mat def get_matched_structure_mapping_old( - uc_struct: Structure, sc_struct: Structure, sm: StructureMatcher | None = None -): + uc_struct: Structure, + sc_struct: Structure, + sm: StructureMatcher | None = None, +) -> tuple[NDArray, ArrayLike] | None: # pragma: no cover """Get the mapping of the supercell to the unit cell. Get the mapping from the supercell structure onto the base structure, @@ -87,7 +85,10 @@ def get_matched_structure_mapping_old( fu, _ = sm._get_supercell_size(s1, s2) try: val, dist, sc_m, total_t, mapping = sm._strict_match( - s1, s2, fu=fu, s1_supercell=True + s1, + s2, + fu=fu, + s1_supercell=True, ) except TypeError: return None @@ -96,8 +97,10 @@ def get_matched_structure_mapping_old( @deprecated(message="This function was reworked in Feb 2024") def get_matched_structure_mapping( - uc_struct: Structure, sc_struct: Structure, sm: StructureMatcher | None = None -): + uc_struct: Structure, + sc_struct: Structure, + sm: StructureMatcher | None = None, +) -> tuple[NDArray, ArrayLike] | None: """Get the mapping of the supercell to the unit cell. Get the mapping from the supercell structure onto the base structure, @@ -113,7 +116,9 @@ def get_matched_structure_mapping( """ if sm is None: sm = StructureMatcher( - primitive_cell=False, comparator=ElementComparator(), attempt_supercell=True + primitive_cell=False, + comparator=ElementComparator(), + attempt_supercell=True, ) s1, s2 = sm._process_species([sc_struct.copy(), uc_struct.copy()]) trans = sm.get_transformation(s1, s2) @@ -160,16 +165,22 @@ def _cubic_cell( try: cst.apply_transformation(base_struct) - except BaseException: + except AttributeError: return _ase_cubic( - base_struct, min_atoms=min_atoms, max_atoms=max_atoms, min_length=min_length + base_struct, + min_atoms=min_atoms, + max_atoms=max_atoms, + min_length=min_length, ) return cst.transformation_matrix def _ase_cubic( - base_structure, min_atoms: int = 80, max_atoms: int = 240, min_length=10.0 -): + base_structure: Structure, + min_atoms: int = 80, + max_atoms: int = 240, + min_length: float = 10.0, +) -> NDArray: """Generate the best supercell from a unit cell. Use ASE's find_optimal_cell_shape function to find the best supercell. @@ -186,7 +197,7 @@ def _ase_cubic( from ase.build import find_optimal_cell_shape, get_deviation_from_optimal_cell_shape from pymatgen.io.ase import AseAtomsAdaptor - _logger.warn("ASE cubic supercell generation.") + _logger.warning("ASE cubic supercell generation.") aaa = AseAtomsAdaptor() ase_atoms = aaa.get_atoms(base_structure) @@ -194,33 +205,38 @@ def _ase_cubic( upper = math.floor(max_atoms / base_structure.num_sites) min_dev = (float("inf"), None) for size in range(lower, upper + 1): - _logger.warn(f"Trying size {size} out of {upper}.") + _logger.warning("Trying size %s", size) sc = find_optimal_cell_shape( - ase_atoms.cell, target_size=size, target_shape="sc" + ase_atoms.cell, + target_size=size, + target_shape="sc", ) sc_cell = aaa.get_atoms(base_structure * sc).cell lattice_lens = np.linalg.norm(sc_cell, axis=1) - _logger.warn(f"{lattice_lens}, {min_length}, {min_dev}") if min(lattice_lens) < min_length: continue deviation = get_deviation_from_optimal_cell_shape(sc_cell, target_shape="sc") min_dev = min(min_dev, (deviation, sc)) if min_dev[1] is None: - raise RuntimeError("Could not find a cubic supercell") + msg = "Could not find a cubic supercell" + raise RuntimeError(msg) return min_dev[1] -def _avg_lat(l1, l2): +def _avg_lat(l1: Lattice, l2: Lattice) -> Lattice: """Get the average lattice from two lattices.""" params = (np.array(l1.parameters) + np.array(l2.parameters)) / 2 return Lattice.from_parameters(*params) -def _lowest_dist(struct, ref_struct): +def _lowest_dist(struct: Structure, ref_struct: Structure) -> ArrayLike: """For each site, return the lowest distance to any site in the reference structure.""" avg_lat = _avg_lat(struct.lattice, ref_struct.lattice) _, d_2 = pbc_shortest_vectors( - avg_lat, struct.frac_coords, ref_struct.frac_coords, return_d2=True + avg_lat, + struct.frac_coords, + ref_struct.frac_coords, + return_d2=True, ) return np.min(d_2, axis=1) @@ -230,7 +246,7 @@ def get_closest_sc_mat( sc_struct: Structure, sm: StructureMatcher | None = None, debug: bool = False, -): +) -> NDArray: """Get the best guess for the supercell matrix that created this defect cell. Args: @@ -249,10 +265,10 @@ def get_closest_sc_mat( fu = int(np.round(sc_struct.lattice.volume / uc_struct.lattice.volume)) candidate_lattices = tuple( - sm._get_lattices(sc_struct.lattice, uc_struct, supercell_size=fu) + sm._get_lattices(sc_struct.lattice, uc_struct, supercell_size=fu), ) - def _get_mean_dist(lattice, sc_mat): + def _get_mean_dist(lattice: Lattice, sc_mat: NDArray) -> float: if ( np.dot(np.cross(lattice.matrix[0], lattice.matrix[1]), lattice.matrix[2]) < 0 diff --git a/pymatgen/analysis/defects/thermo.py b/pymatgen/analysis/defects/thermo.py index e2a62474..27c42cb8 100644 --- a/pymatgen/analysis/defects/thermo.py +++ b/pymatgen/analysis/defects/thermo.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from itertools import chain, groupby from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable import numpy as np from matplotlib import pyplot as plt @@ -23,14 +23,18 @@ from pymatgen.core import Composition, Element from pymatgen.electronic_structure.dos import FermiDos from pymatgen.entries.computed_entries import ComputedEntry -from pymatgen.io.vasp import Chgcar, Locpot, Vasprun, VolumetricData +from pymatgen.io.vasp import Locpot, Vasprun, VolumetricData from pyrho.charge_density import get_volumetric_like_sc from scipy.constants import value as _cd from scipy.optimize import bisect from scipy.spatial import ConvexHull if TYPE_CHECKING: + from collections.abc import Generator, Sequence + + from matplotlib.axes import Axes from numpy.typing import ArrayLike, NDArray + from pandas import DataFrame from pymatgen.analysis.defects.utils import CorrectionResult from pymatgen.core import Structure from pymatgen.electronic_structure.dos import Dos @@ -96,8 +100,8 @@ def get_freysoldt_correction( defect_locpot: Locpot | dict, bulk_locpot: Locpot | dict, dielectric: float | NDArray, - defect_struct: Optional[Structure] = None, - bulk_struct: Optional[Structure] = None, + defect_struct: Structure | None = None, + bulk_struct: Structure | None = None, **kwargs, ) -> CorrectionResult: """Calculate the Freysoldt correction. @@ -120,7 +124,7 @@ def get_freysoldt_correction( bulk_struct: The bulk structure. If None, the structure of the bulk_locpot will be used. - kwargs: + **kwargs: Additional keyword arguments for the get_correction method. Returns: @@ -134,8 +138,9 @@ def get_freysoldt_correction( bulk_struct = getattr(bulk_locpot, "structure", None) if defect_struct is None or bulk_struct is None: # pragma: no cover + msg = "defect_struct and/or bulk_struct is missing either provide the structure or provide the complete locpot." raise ValueError( - "defect_struct and/or bulk_struct is missing either provide the structure or provide the complete locpot." + msg, ) if self.sc_defect_frac_coords is None: @@ -174,7 +179,7 @@ def get_freysoldt_correction( self.corrections.update( { "freysoldt": frey_corr.correction_energy, - } + }, ) self.corrections_metadata.update({"freysoldt": frey_corr.metadata.copy()}) return frey_corr @@ -187,9 +192,12 @@ def corrected_energy(self) -> float: def get_ediff(self) -> float | None: """Get the energy difference between the defect and the bulk (including finite-size correction).""" if self.bulk_entry is None: - raise RuntimeError( + msg = ( "Attempting to compute the energy difference without a bulk entry data." ) + raise RuntimeError( + msg, + ) return self.corrected_energy - self.bulk_entry.energy def get_summary_dict(self) -> dict: @@ -242,15 +250,15 @@ class FormationEnergyDiagram(MSONable): the convex hull. """ - defect_entries: List[DefectEntry] + defect_entries: list[DefectEntry] pd_entries: list[ComputedEntry] vbm: float - band_gap: Optional[float] = None - bulk_entry: Optional[ComputedStructureEntry] = None + band_gap: float | None = None + bulk_entry: ComputedStructureEntry | None = None inc_inf_values: bool = False bulk_stability: float = 0.001 - def __post_init__(self): + def __post_init__(self) -> None: """Post-initialization. - Reconstruct the phase diagram with the bulk entry @@ -259,41 +267,51 @@ def __post_init__(self): """ g = group_defect_entries(self.defect_entries) if next(g, True) and next(g, False): - raise ValueError( + msg = ( "Defects are not of same type! " "Use MultiFormationEnergyDiagram for multiple defect types" ) + raise ValueError( + msg, + ) # if all of the `DefectEntry` objects have the same `bulk_entry` then `self.bulk_entry` is not needed - if any(d.bulk_entry is None for d in self.defect_entries): - if self.bulk_entry is None: - raise RuntimeError( - "Not all of the `DefectEntry` objects have a `bulk_entry` attribute, you need to provide `bulk_entry` to `FormationEnergyDiagram`" - ) + if self.bulk_entry is None and any( + x.bulk_entry is None for x in self.defect_entries + ): + msg = "The bulk entry must be provided." + raise RuntimeError( + msg, + ) - bulk_entry = self.bulk_entry or self.defect_entries[0].bulk_entry - self.bulk_entry = bulk_entry + bulk_entry = self.bulk_entry or min( + [x.bulk_entry for x in self.defect_entries], + key=lambda x: x.energy_per_atom, + ) pd_ = PhaseDiagram(self.pd_entries) entries = pd_.stable_entries | {bulk_entry} pd_ = PhaseDiagram(entries) self.phase_diagram = ensure_stable_bulk(pd_, bulk_entry, self.bulk_stability) entries = [] for entry in self.phase_diagram.stable_entries: - d_ = dict( - energy=self.phase_diagram.get_form_energy(entry), - composition=entry.composition, - entry_id=entry.entry_id, - correction=0.0, - ) + d_ = { + "energy": self.phase_diagram.get_form_energy(entry), + "composition": entry.composition, + "entry_id": entry.entry_id, + "correction": 0.0, + } entries.append(ComputedEntry.from_dict(d_)) entries.append(ComputedEntry.from_dict(d_)) self.chempot_diagram = ChemicalPotentialDiagram(entries) if ( bulk_entry.composition.reduced_formula not in self.chempot_diagram.domains ): # pragma: no cover - raise ValueError( + msg = ( "Bulk entry is not stable in the chemical potential diagram." "Consider increasing the `bulk_stability` to make it more stable." ) + raise ValueError( + msg, + ) chempot_limits = self.chempot_diagram.domains[ bulk_entry.composition.reduced_formula ] @@ -320,7 +338,7 @@ def with_atomic_entries( vbm: float, bulk_entry: ComputedEntry | None = None, **kwargs, - ): + ) -> FormationEnergyDiagram: """Create a FormationEnergyDiagram object using an existing phase diagram. Since the Formation energy usually looks like: @@ -363,7 +381,8 @@ def with_atomic_entries( The FormationEnergyDiagram object. """ adjusted_entries = _get_adjusted_pd_entries( - phase_diagram=phase_diagram, atomic_entries=atomic_entries + phase_diagram=phase_diagram, + atomic_entries=atomic_entries, ) return cls( bulk_entry=bulk_entry, @@ -376,13 +395,13 @@ def with_atomic_entries( @classmethod def with_directories( cls, - directory_map: Dict[str, str], + directory_map: dict[str, str | Path], defect: Defect, pd_entries: list[ComputedEntry], dielectric: float | NDArray, vbm: float | None = None, **kwargs, - ): + ) -> FormationEnergyDiagram: """Create a FormationEnergyDiagram from VASP directories. Args: @@ -397,14 +416,16 @@ def with_directories( **kwargs: Additional keyword arguments for the constructor. """ - def _read_dir(directory): + def _read_dir(directory: str | Path) -> tuple[ComputedEntry, Locpot]: + directory = Path(directory) vr = Vasprun(get_zfile(Path(directory), "vasprun.xml")) ent = vr.get_computed_entry() locpot = Locpot.from_file(get_zfile(directory, "LOCPOT")) return ent, locpot if "bulk" not in directory_map: - raise ValueError("The bulk directory must be provided.") + msg = "The bulk directory must be provided." + raise ValueError(msg) bulk_entry, bulk_locpot = _read_dir(directory_map["bulk"]) def_entries = [] @@ -419,7 +440,9 @@ def _read_dir(directory): ) q_d_entry.get_freysoldt_correction( - defect_locpot=q_locpot, bulk_locpot=bulk_locpot, dielectric=dielectric + defect_locpot=q_locpot, + bulk_locpot=bulk_locpot, + dielectric=dielectric, ) def_entries.append(q_d_entry) if vbm is None: @@ -478,32 +501,23 @@ def _vbm_formation_energy(self, defect_entry: DefectEntry, chempots: dict) -> fl [ (self.dft_energies[Element(el)] + chempots[Element(el)]) * fac for el, fac in defect_entry.defect.element_changes.items() - ] + ], ) - if self.bulk_entry is not None: - formation_en = ( - defect_entry.corrected_energy - - self.bulk_entry.energy - - en_change - + self.vbm * defect_entry.charge_state - ) - else: - formation_en = ( - defect_entry.get_ediff() - - en_change - + self.vbm * defect_entry.charge_state - ) + try: + ediff = defect_entry.get_ediff() + except RuntimeError: + ediff = defect_entry.corrected_energy - self.bulk_entry.energy - return formation_en + return ediff - en_change + self.vbm * defect_entry.charge_state @property - def chempot_limits(self): + def chempot_limits(self) -> list[dict[Element, float]]: """Return the chemical potential limits in dictionary format.""" - res = [] - for vertex in self._chempot_limits_arr: - res.append(dict(zip(self.chempot_diagram.elements, vertex))) - return res + return [ + dict(zip(self.chempot_diagram.elements, vertex)) + for vertex in self._chempot_limits_arr + ] @property def competing_phases(self) -> list[dict[str, ComputedEntry]]: @@ -512,7 +526,7 @@ def competing_phases(self) -> list[dict[str, ComputedEntry]]: cd = self.chempot_diagram res = [] for pt in self._chempot_limits_arr: - competing_phases = dict() + competing_phases = {} for hp_ent, hp in zip(cd._hyperplane_entries, cd._hyperplanes): if hp_ent.composition.reduced_formula == bulk_formula: continue @@ -522,11 +536,11 @@ def competing_phases(self) -> list[dict[str, ComputedEntry]]: return res @property - def defect(self): + def defect(self) -> Defect: """Get the defect that this FormationEnergyDiagram represents.""" return self.defect_entries[0].defect - def _get_lines(self, chempots: Dict) -> list[tuple[float, float]]: + def _get_lines(self, chempots: dict) -> list[tuple[float, float]]: """Get the lines for the formation energy diagram. Args: @@ -549,7 +563,10 @@ def _get_lines(self, chempots: Dict) -> list[tuple[float, float]]: return lines def get_transitions( - self, chempots: dict, x_min: float = 0, x_max: float | None = None + self, + chempots: dict, + x_min: float = 0, + x_max: float | None = None, ) -> list[tuple[float, float]]: """Get the transition levels for the formation energy diagram. @@ -573,14 +590,14 @@ def get_transitions( VBM and CBM respectively. """ chempots = self._parse_chempots(chempots) - if x_max is None: + if x_max is None: # pragma: no cover x_max = self.band_gap lines = self._get_lines(chempots) lines = get_lower_envelope(lines) return get_transitions(lines, x_min, x_max) - def get_formation_energy(self, fermi_level: float, chempot_dict: dict): + def get_formation_energy(self, fermi_level: float, chempot_dict: dict) -> float: """Get the formation energy at a given Fermi level. Linearly interpolate between the transition levels. @@ -595,13 +612,16 @@ def get_formation_energy(self, fermi_level: float, chempot_dict: dict): The formation energy at the given Fermi level. """ transitions = np.array( - self.get_transitions(chempot_dict, x_min=-100, x_max=100) + self.get_transitions(chempot_dict, x_min=-100, x_max=100), ) # linearly interpolate between the set of points return np.interp(fermi_level, transitions[:, 0], transitions[:, 1]) def get_concentration( - self, fermi_level: float, chempots: dict, temperature: int | float + self, + fermi_level: float, + chempots: dict, + temperature: float, ) -> float: """Get equilibrium defect concentration assuming the dilute limit. @@ -613,19 +633,19 @@ def get_concentration( chempots = self._parse_chempots(chempots=chempots) fe = self.get_formation_energy(fermi_level, chempots) return self.defect_entries[0].defect.multiplicity * fermi_dirac( - energy=fe, temperature=temperature + energy=fe, + temperature=temperature, ) - def as_dataframe(self): + def as_dataframe(self) -> DataFrame: """Return the formation energy diagram as a pandas dataframe.""" from pandas import DataFrame defect_entries = self.defect_entries - l_ = map(lambda x: x.get_summary_dict(), defect_entries) - df = DataFrame(l_) - return df + l_ = (x.get_summary_dict() for x in defect_entries) + return DataFrame(l_) - def get_chempots(self, rich_element: Element | str, en_tol: float = 0.01): + def get_chempots(self, rich_element: Element | str, en_tol: float = 0.01) -> dict: """Get the chemical potential for a desired growth condition. Choose an element to be rich in, require the chemical potential of that element @@ -653,37 +673,39 @@ def get_chempots(self, rich_element: Element | str, en_tol: float = 0.01): max_val = max(self.chempot_limits, key=lambda x: x[rich_element])[rich_element] rich_conditions = list( filter( - lambda cp: abs(cp[rich_element] - max_val) < en_tol, self.chempot_limits - ) + lambda cp: abs(cp[rich_element] - max_val) < en_tol, + self.chempot_limits, + ), ) - if len(rich_conditions) == 0: + if len(rich_conditions) == 0: # pragma: no cover + msg = f"Cannot find a chemical potential condition with {rich_element} near zero." raise ValueError( - f"Cannot find a chemical potential condition with {rich_element} near zero." + msg, ) # defect = self.defect_entries[0].defect in_bulk = self.defect_entries[0].sc_entry.composition.elements # make sure they are of type Element - in_bulk = list(map(lambda x: Element(x.symbol), in_bulk)) + in_bulk = [Element(x.symbol) for x in in_bulk] not_in_bulk = list(set(self.chempot_limits[0].keys()) - set(in_bulk)) in_bulk = list(filter(lambda x: x != rich_element, in_bulk)) - def el_sorter(element): + def el_sorter(element: Element) -> float: return -abs(element.electron_affinity - rich_element.electron_affinity) el_list = sorted(in_bulk, key=el_sorter) + sorted(not_in_bulk, key=el_sorter) - def chempot_sorter(chempot_dict): + def chempot_sorter(chempot_dict: dict) -> tuple[float, ...]: return tuple(chempot_dict[el] for el in el_list) return min(rich_conditions, key=chempot_sorter) def __repr__(self) -> str: """Representation.""" - defect_entry_summary = [] - for dent in self.defect_entries: - defect_entry_summary.append( - f"\t{dent.defect.name} {dent.charge_state} {dent.corrected_energy}" - ) + defect_entry_summary = [ + f"\t{dent.defect.name} {dent.charge_state} {dent.corrected_energy}" + for dent in self.defect_entries + ] + txt = ( f"{self.__class__.__name__} for {self.defect.name}", "Defect Entries:", @@ -696,9 +718,9 @@ def __repr__(self) -> str: class MultiFormationEnergyDiagram(MSONable): """Container for multiple formation energy diagrams.""" - formation_energy_diagrams: List[FormationEnergyDiagram] + formation_energy_diagrams: list[FormationEnergyDiagram] - def __post_init__(self): + def __post_init__(self) -> None: """Set some attributes after initialization.""" self.band_gap = self.formation_energy_diagrams[0].band_gap self.vbm = self.formation_energy_diagrams[0].vbm @@ -741,7 +763,10 @@ def with_atomic_entries( return cls(formation_energy_diagrams=single_form_en_diagrams) def solve_for_fermi_level( - self, chempots: dict, temperature: int | float, dos: Dos + self, + chempots: dict, + temperature: float, + dos: Dos, ) -> float: """Solves for the equilibrium fermi level at a given chempot, temperature, density of states. @@ -762,7 +787,7 @@ def solve_for_fermi_level( fdos_multiplicity = fdos_factor / bulk_factor fdos_cbm, fdos_vbm = fdos.get_cbm_vbm() - def _get_chg(fd: FormationEnergyDiagram, ef): + def _get_chg(fd: FormationEnergyDiagram, ef: float) -> float: lines = fd._get_lines(chempots=chempots) return sum( fd.defect.multiplicity @@ -771,12 +796,13 @@ def _get_chg(fd: FormationEnergyDiagram, ef): for charge, vbm_fe in lines ) - def _get_total_q(ef): + def _get_total_q(ef: float) -> float: qd_tot = sum( _get_chg(fd=fd, ef=ef) for fd in self.formation_energy_diagrams ) qd_tot += fdos_multiplicity * fdos.get_doping( - fermi_level=ef + fdos_vbm, temperature=temperature + fermi_level=ef + fdos_vbm, + temperature=temperature, ) return qd_tot @@ -784,8 +810,9 @@ def _get_total_q(ef): def group_defect_entries( - defect_entries: list[DefectEntry], sm: StructureMatcher = None -): + defect_entries: list[DefectEntry], + sm: StructureMatcher = None, +) -> Generator[tuple[str, list[DefectEntry]], None, None]: """Group defect entries by their representation. First by name then by structure. @@ -800,33 +827,34 @@ def group_defect_entries( if sm is None: sm = StructureMatcher(comparator=ElementComparator()) - def _get_structure(entry): + def _get_structure(entry: DefectEntry) -> Structure: return entry.defect.defect_structure - def _get_name(entry): + def _get_name(entry: DefectEntry) -> str: return entry.defect.name - def _get_hash_no_structure(entry): + def _get_hash_no_structure(entry: DefectEntry) -> tuple[str, str]: return entry.defect.bulk_formula, entry.defect.name if all(isinstance(entry.defect, Defect) for entry in defect_entries): ent_groups = group_docs( - defect_entries, sm=sm, get_structure=_get_structure, get_hash=_get_name + defect_entries, + sm=sm, + get_structure=_get_structure, + get_hash=_get_name, ) - for g_name, g_entries in ent_groups: - yield g_name, g_entries + yield from ent_groups elif all(isinstance(entry.defect, NamedDefect) for entry in defect_entries): l_ = sorted(defect_entries, key=_get_hash_no_structure) - for _, g_entries in groupby(l_, key=_get_hash_no_structure): - similar_ents = list(g_entries) + for _, g_entries_no_struct in groupby(l_, key=_get_hash_no_structure): + similar_ents = list(g_entries_no_struct) yield similar_ents[0].defect.name, similar_ents def group_formation_energy_diagrams( feds: list[FormationEnergyDiagram], sm: StructureMatcher = None, - combine_diagrams: bool = True, -): +) -> Generator[tuple[str | None, FormationEnergyDiagram], None, None]: """Group formation energy diagrams by their representation. First by name then by structure. @@ -834,7 +862,6 @@ def group_formation_energy_diagrams( Args: feds: list of formation energy diagrams sm: StructureMatcher to use for grouping - combine_diagrams: whether to combine matching diagrams into a single diagram Returns: If combine_diagrams is True, generator of (name, combined formation energy diagram) tuples. @@ -844,24 +871,24 @@ def group_formation_energy_diagrams( if sm is None: sm = StructureMatcher(comparator=ElementComparator()) - def _get_structure(fed): + def _get_structure(fed: FormationEnergyDiagram) -> Structure: return fed.defect.defect_structure - def _get_name(fed): + def _get_name(fed: FormationEnergyDiagram) -> str: return fed.defect.name fed_groups = group_docs( - feds, sm=sm, get_structure=_get_structure, get_hash=_get_name + feds, + sm=sm, + get_structure=_get_structure, + get_hash=_get_name, ) for g_name, f_group in fed_groups: - if combine_diagrams: - fed = f_group[0] - fed_d = fed.as_dict() - dents = [dfed.defect_entries for dfed in f_group] - fed_d["defect_entries"] = list(chain.from_iterable(dents)) - yield g_name, FormationEnergyDiagram.from_dict(fed_d) - else: - yield g_name, f_group + fed = f_group[0] + fed_d = fed.as_dict() + dents = [dfed.defect_entries for dfed in f_group] + fed_d["defect_entries"] = list(chain.from_iterable(dents)) + yield g_name, FormationEnergyDiagram.from_dict(fed_d) def ensure_stable_bulk( @@ -892,10 +919,10 @@ def ensure_stable_bulk( Modified Phase diagram. """ stable_entry = ComputedEntry( - entry.composition, pd.get_hull_energy(entry.composition) - threshold + entry.composition, + pd.get_hull_energy(entry.composition) - threshold, ) - pd = PhaseDiagram([*pd.all_entries, stable_entry]) - return pd + return PhaseDiagram([*pd.all_entries, stable_entry]) def get_sc_locpot( @@ -904,7 +931,7 @@ def get_sc_locpot( grid_out: tuple, up_sample: int = 2, sm: StructureMatcher = None, -): +) -> Locpot: """Transform a unit cell locpot to be like a supercell locpot. This is useful in situations where the supercell bulk locpot is not available. @@ -923,7 +950,7 @@ def get_sc_locpot( """ sc_mat = get_closest_sc_mat(uc_locpot.structure, sc_struct=defect_struct, sm=sm) bulk_sc = uc_locpot.structure * sc_mat - sc_locpot = get_volumetric_like_sc( + return get_volumetric_like_sc( uc_locpot, bulk_sc, grid_out=grid_out, @@ -931,11 +958,12 @@ def get_sc_locpot( sm=sm, normalization=None, ) - return sc_locpot def get_transitions( - lines: list[tuple[float, float]], x_min: float, x_max: float + lines: list[tuple[float, float]], + x_min: float, + x_max: float, ) -> list[tuple[float, float]]: """Get the "transition" points in a list of lines. @@ -959,8 +987,9 @@ def get_transitions( for i, (m1, b1) in enumerate(lines[:-1]): m2, b2 = lines[i + 1] if m1 == m2: + msg = "The slopes (charge states) of the set of lines should be distinct." raise ValueError( - "The slopes (charge states) of the set of lines should be distinct." + msg, ) # pragma: no cover nx, ny = ((b2 - b1) / (m1 - m2), (m1 * b2 - m2 * b1) / (m1 - m2)) if nx < x_min: @@ -975,7 +1004,7 @@ def get_transitions( return transitions -def get_lower_envelope(lines): +def get_lower_envelope(lines: list[tuple[float, float]]) -> list[tuple[float, float]]: """Get the lower envelope of the formation energy. Based on the fact that the lower envelope of the lines is @@ -991,28 +1020,28 @@ def get_lower_envelope(lines): List lines that make up the lower envelope. """ - def _hash_float(x): + def _hash_float(x: float) -> float: return round(x, 10) - lines_dd = collections.defaultdict(lambda: float("inf")) + lines_dd: dict = collections.defaultdict(lambda: float("inf")) for m, b in lines: lines_dd[_hash_float(m)] = min(lines_dd[_hash_float(m)], b) - lines = [(m, b) for m, b in lines_dd.items()] + lines = list(lines_dd.items()) - if len(lines) < 1: - raise ValueError("Need at least one line to get lower envelope.") - elif len(lines) == 1: + if len(lines) < 1: # pragma: no cover + msg = "Need at least one line to get lower envelope." + raise ValueError(msg) + if len(lines) == 1: return lines - elif len(lines) == 2: + if len(lines) == 2: return sorted(lines) dual_points = [(m, -b) for m, b in lines] upper_hull = get_upper_hull(dual_points) - lower_envelope = [(m, -b) for m, b in upper_hull] - return lower_envelope + return [(m, -b) for m, b in upper_hull] -def get_upper_hull(points: ArrayLike) -> List[ArrayLike]: +def get_upper_hull(points: ArrayLike) -> list[ArrayLike]: """Get the upper hull of a set of points in 2D. Args: @@ -1047,7 +1076,9 @@ def get_upper_hull(points: ArrayLike) -> List[ArrayLike]: return upper_hull -def _get_adjusted_pd_entries(phase_diagram, atomic_entries) -> list[ComputedEntry]: +def _get_adjusted_pd_entries( + phase_diagram: PhaseDiagram, atomic_entries: Sequence[ComputedEntry] +) -> list[ComputedEntry]: """Get the adjusted entries for the phase diagram. Combine the terminal energies from ``atomic_entries`` with the enthalpies of formation @@ -1061,14 +1092,15 @@ def _get_adjusted_pd_entries(phase_diagram, atomic_entries) -> list[ComputedEntr List[ComputedEntry]: Entries for the new phase diagram. """ - def get_interp_en(entry: ComputedEntry): + def get_interp_en(entry: ComputedEntry) -> float: """Get the interpolated energy of an entry.""" - e_dict = dict() + e_dict = {} for e in atomic_entries: - if len(e.composition.elements) != 1: + if len(e.composition.elements) != 1: # pragma: no cover + msg = "Only single-element entries should be provided." raise ValueError( - "Only single-element entries should be provided." - ) # pragma: no cover + msg, + ) e_dict[e.composition.elements[0]] = e.energy_per_atom return sum( @@ -1078,18 +1110,18 @@ def get_interp_en(entry: ComputedEntry): adjusted_entries = [] for entry in phase_diagram.stable_entries: - d_ = dict( - energy=get_interp_en(entry) + phase_diagram.get_form_energy(entry), - composition=entry.composition, - entry_id=entry.entry_id, - correction=0, - ) + d_ = { + "energy": get_interp_en(entry) + phase_diagram.get_form_energy(entry), + "composition": entry.composition, + "entry_id": entry.entry_id, + "correction": 0, + } adjusted_entries.append(ComputedEntry.from_dict(d_)) return adjusted_entries -def fermi_dirac(energy: float, temperature: int | float) -> float: +def fermi_dirac(energy: float, temperature: float) -> float: """Get value of fermi dirac distribution. Gets the defects equilibrium concentration (up to the multiplicity factor) @@ -1105,10 +1137,10 @@ def fermi_dirac(energy: float, temperature: int | float) -> float: def plot_formation_energy_diagrams( formation_energy_diagrams: FormationEnergyDiagram - | List[FormationEnergyDiagram] + | list[FormationEnergyDiagram] | MultiFormationEnergyDiagram, rich_element: Element | None = None, - chempots: Dict | None = None, + chempots: dict | None = None, alignment: float = 0.0, xlim: list | None = None, ylim: list | None = None, @@ -1123,12 +1155,12 @@ def plot_formation_energy_diagrams( linewidth: int = 4, envelope_alpha: float = 0.8, line_alpha: float = 0.5, - band_edge_color="k", + band_edge_color: str = "k", filterfunction: Callable | None = None, legend_loc: str = "lower center", show_legend: bool = True, - axis=None, -): + axis: Axes = None, +) -> Axes: """Plot the formation energy diagram. Args: @@ -1172,12 +1204,13 @@ def plot_formation_energy_diagrams( elif isinstance(formation_energy_diagrams, FormationEnergyDiagram): formation_energy_diagrams = [formation_energy_diagrams] - filterfunction = filterfunction if filterfunction else lambda x: True + filterfunction = filterfunction if filterfunction else lambda _x: True formation_energy_diagrams = list(filter(filterfunction, formation_energy_diagrams)) band_gap = formation_energy_diagrams[0].band_gap if not xlim and band_gap is None: - raise ValueError("Must specify xlim or set band_gap attribute") + msg = "Must specify xlim or set band_gap attribute" + raise ValueError(msg) if axis is None: _, axis = plt.subplots() @@ -1202,7 +1235,7 @@ def plot_formation_energy_diagrams( named_feds.append((name_, fed_)) color_line_gen = _get_line_color_and_style(colors, linestyle) - for fid, (fed_name, single_fed) in enumerate(named_feds): + for _fid, (fed_name, single_fed) in enumerate(named_feds): cur_color, cur_style = next(color_line_gen) chempots_ = ( chempots @@ -1213,8 +1246,10 @@ def plot_formation_energy_diagrams( lowerlines = get_lower_envelope(lines) trans = np.array( get_transitions( - lowerlines, np.add(xmin, alignment), np.add(xmax, alignment) - ) + lowerlines, + np.add(xmin, alignment), + np.add(xmax, alignment), + ), ) trans_y = trans[:, 1] ymin = min(ymin, *trans_y) @@ -1245,7 +1280,10 @@ def plot_formation_energy_diagrams( x = np.linspace(xmin, xmax) y = line[0] * x + line[1] axis.plot( - np.subtract(x, alignment), y, color=cur_color, alpha=line_alpha + np.subtract(x, alignment), + y, + color=cur_color, + alpha=line_alpha, ) axis.set_xlim(xmin, xmax) @@ -1276,7 +1314,11 @@ def plot_formation_energy_diagrams( axis.axvline(0, ls="--", color="k", lw=2, alpha=0.2) axis.axvline( - np.subtract(0, alignment), ls="--", color=band_edge_color, lw=2, alpha=0.8 + np.subtract(0, alignment), + ls="--", + color=band_edge_color, + lw=2, + alpha=0.8, ) if band_gap: axis.axvline( @@ -1311,7 +1353,9 @@ def plot_formation_energy_diagrams( return axis -def _get_line_color_and_style(colors=None, styles=None): +def _get_line_color_and_style( + colors: Sequence | None = None, styles: Sequence | None = None +) -> Generator[tuple[str, str], None, None]: """Get a generator for colors and styles. Create an iterator that will cycle through the colors and styles. @@ -1335,7 +1379,7 @@ def _get_line_color_and_style(colors=None, styles=None): yield color, style -def _is_on_hyperplane(pt: np.array, hp: np.array, tol: float = 1e-8): +def _is_on_hyperplane(pt: np.array, hp: np.array, tol: float = 1e-8) -> bool: """Check if a point lies on a hyperplane. Args: diff --git a/pymatgen/analysis/defects/utils.py b/pymatgen/analysis/defects/utils.py index b77bc220..6cc0ab31 100644 --- a/pymatgen/analysis/defects/utils.py +++ b/pymatgen/analysis/defects/utils.py @@ -12,7 +12,7 @@ from copy import deepcopy from dataclasses import dataclass from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Generator +from typing import TYPE_CHECKING, Any, Callable import numpy as np from monty.dev import deprecated @@ -30,6 +30,7 @@ from scipy.spatial.distance import squareform if TYPE_CHECKING: + from collections.abc import Generator from pathlib import Path from numpy import typing as npt @@ -78,7 +79,9 @@ class QModel(MSONable): If defect charge is more delocalized, exponential tail is suggested. """ - def __init__(self, beta=1.0, expnorm=0.0, gamma=1.0): + def __init__( + self, beta: float = 1.0, expnorm: float = 0.0, gamma: float = 1.0 + ) -> None: """Initialize the model. Args: @@ -96,9 +99,10 @@ def __init__(self, beta=1.0, expnorm=0.0, gamma=1.0): self.beta2 = beta * beta self.gamma2 = gamma * gamma if expnorm and not gamma: - raise ValueError("Please supply exponential decay constant.") + msg = "Please supply exponential decay constant." + raise ValueError(msg) - def rho_rec(self, g2): + def rho_rec(self, g2: float) -> float: """Reciprocal space model charge value. Reciprocal space model charge value, for input squared reciprocal vector. @@ -114,7 +118,7 @@ def rho_rec(self, g2): ) * np.exp(-0.25 * self.beta2 * g2) @property - def rho_rec_limit0(self): + def rho_rec_limit0(self) -> float: """Reciprocal space model charge value. Close to reciprocal vector 0 . @@ -123,7 +127,7 @@ def rho_rec_limit0(self): return -2 * self.gamma2 * self.expnorm - 0.25 * self.beta2 * (1 - self.expnorm) -def eV_to_k(energy): +def eV_to_k(energy: float) -> float: """Convert energy to reciprocal vector magnitude k via hbar*k^2/2m. Args: @@ -135,7 +139,9 @@ def eV_to_k(energy): return math.sqrt(energy / invang_to_ev) * ang_to_bohr -def genrecip(a1, a2, a3, encut) -> Generator[npt.ArrayLike, None, None]: +def genrecip( + a1: npt.ArrayLike, a2: npt.ArrayLike, a3: npt.ArrayLike, encut: float +) -> Generator[npt.ArrayLike, None, None]: """Generate reciprocal lattice vectors within the energy cutoff. Args: @@ -177,7 +183,9 @@ def genrecip(a1, a2, a3, encut) -> Generator[npt.ArrayLike, None, None]: yield vec -def generate_reciprocal_vectors_squared(a1, a2, a3, encut): +def generate_reciprocal_vectors_squared( + a1: npt.ArrayLike, a2: npt.ArrayLike, a3: npt.ArrayLike, encut: float +) -> Generator[float, None, None]: """Generate Reciprocal vectors squared. Generate reciprocal vector magnitudes within the cutoff along the specified @@ -197,7 +205,7 @@ def generate_reciprocal_vectors_squared(a1, a2, a3, encut): yield np.dot(vec, vec) -def converge(f, step, tol, max_h) -> float: +def converge(f: Callable, step: float, tol: float, max_h: float) -> float: """Simple newton iteration based convergence function. Args: @@ -219,12 +227,15 @@ def converge(f, step, tol, max_h) -> float: h += step if h > max_h: - raise Exception(f"Did not converge before {h}") + msg = f"Did not converge before {h}" + raise Exception(msg) return g def get_zfile( - directory: Path, base_name: str, allow_missing: bool = False + directory: Path, + base_name: str, + allow_missing: bool = False, ) -> Path | None: """Find gzipped or non-gzipped versions of a file in a directory listing. @@ -238,20 +249,21 @@ def get_zfile( and the file cannot be found, then ``None`` will be returned. """ for file in directory.glob(f"{base_name}*"): - if base_name == file.name: - return file - elif base_name + ".gz" == file.name: - return file - elif base_name + ".GZ" == file.name: + if ( + base_name == file.name + or base_name + ".gz" == file.name + or base_name + ".GZ" == file.name + ): return file if allow_missing: return None - raise FileNotFoundError(f"Could not find {base_name} or {base_name}.gz file.") + msg = f"Could not find {base_name} or {base_name}.gz file." + raise FileNotFoundError(msg) -def generic_group_labels(list_in, comp=operator.eq): +def generic_group_labels(list_in: list, comp: Callable = operator.eq) -> list[int]: """Group a list of unsortable objects. Args: @@ -262,7 +274,7 @@ def generic_group_labels(list_in, comp=operator.eq): list[int]: list of labels for the input list """ - list_out = [None] * len(list_in) + list_out: list[int | None] = [None] * len(list_in) label_num = 0 for i1, ls1 in enumerate(list_out): if ls1 is not None: @@ -294,10 +306,7 @@ def get_local_extrema(chgcar: VolumetricData, find_min: bool = True) -> npt.NDAr extrema_coords (list): list of fractional coordinates corresponding to local extrema. """ - if find_min: - sign = -1 - else: - sign = 1 + sign = -1 if find_min else 1 # Make 3x3x3 supercell # This is a trick to resolve the periodical boundary issue. @@ -316,7 +325,9 @@ def get_local_extrema(chgcar: VolumetricData, find_min: bool = True) -> npt.NDAr def remove_collisions( - fcoords: npt.NDArray, structure: Structure, min_dist: float = 0.9 + fcoords: npt.NDArray, + structure: Structure, + min_dist: float = 0.9, ) -> npt.NDArray: """Removed points that are too close to existing atoms in the structure. @@ -334,12 +345,14 @@ def remove_collisions( dist_matrix = structure.lattice.get_all_distances(fcoords, s_fcoord) all_dist = np.min(dist_matrix, axis=1) return np.array( - [fcoords[i] for i in range(len(fcoords)) if all_dist[i] >= min_dist] + [fcoords[i] for i in range(len(fcoords)) if all_dist[i] >= min_dist], ) def cluster_nodes( - fcoords: npt.ArrayLike, lattice: Lattice, tol: float = 0.2 + fcoords: npt.ArrayLike, + lattice: Lattice, + tol: float = 0.2, ) -> npt.NDArray: """Cluster nodes that are too close together using hiercharcal clustering. @@ -384,7 +397,9 @@ def cluster_nodes( def get_avg_chg( - chgcar: VolumetricData, fcoord: npt.ArrayLike, radius: float = 0.4 + chgcar: VolumetricData, + fcoord: npt.ArrayLike, + radius: float = 0.4, ) -> float: """Get the average charge in a sphere. @@ -401,7 +416,7 @@ def get_avg_chg( # makesure fcoord is an array fcoord = np.array(fcoord) - def _dist_mat(pos_frac): + def _dist_mat(pos_frac: npt.ArrayLike) -> npt.NDArray: # return a matrix that contains the distances aa = np.linspace(0, 1, len(chgcar.get_axis_grid(0)), endpoint=False) bb = np.linspace(0, 1, len(chgcar.get_axis_grid(1)), endpoint=False) @@ -414,11 +429,11 @@ def _dist_mat(pos_frac): return dist_from_pos.reshape(AA.shape) if np.any(fcoord < 0) or np.any(fcoord > 1): - raise ValueError("f_coords must be in [0,1)") + msg = "f_coords must be in [0,1)" + raise ValueError(msg) mask = _dist_mat(fcoord) < radius vol_sphere = chgcar.structure.volume * (mask.sum() / chgcar.ngridpts) - avg_chg = np.sum(chgcar.data["total"] * mask) / mask.size / vol_sphere - return avg_chg + return np.sum(chgcar.data["total"] * mask) / mask.size / vol_sphere class TopographyAnalyzer: @@ -440,20 +455,20 @@ class TopographyAnalyzer: def __init__( self, - structure, - framework_ions, - cations, - image_tol=0.0001, - max_cell_range=1, - check_volume=True, - constrained_c_frac=0.5, - thickness=0.5, + structure: Structure, + framework_ions: list[str], + cations: list[str], + image_tol: float = 0.0001, + max_cell_range: int = 1, + check_volume: bool = True, + constrained_c_frac: float = 0.5, + thickness: float = 0.5, clustering_tol: float = 0.5, min_dist: float = 0.9, ltol: float = 0.2, stol: float = 0.3, angle_tol: float = 5, - ): + ) -> None: """Initialize the TopographyAnalyzer. Args: @@ -510,6 +525,7 @@ def __init__( max_cell_range = 2 # Let us first map all sites to the standard unit cell, i.e., + # 0 ≤ coordinates < 1. # structure = Structure.from_sites(structure, to_unit_cell=True) # lattice = structure.lattice @@ -519,12 +535,12 @@ def __init__( # mapping all sites to the standard unit cell self.structure = structure.copy() - # TODO: Structure is still being mutated something weird is going on but the code works. + # NOTE: Structure is still being mutated something weird is going on but the code works. # remove oxidation state self.structure.remove_oxidation_states() constrained_sites = [] - for i, site in enumerate(self.structure): + for _i, site in enumerate(self.structure): if ( site.frac_coords[2] >= constrained_c_frac - thickness and site.frac_coords[2] <= constrained_c_frac + thickness @@ -567,14 +583,14 @@ def __init__( for v in vs: node_points_map[v].update(pts) - _logger.debug(f"{len(voro.vertices)} total Voronoi vertices") + _logger.debug("Voronoi vertices in cell: %s", len(voro.vertices)) # Vnodes store all the valid voronoi polyhedra. Cation vnodes store # the voronoi polyhedra that are already occupied by existing cations. vnodes: list[VoronoiPolyhedron] = [] cation_vnodes = [] - def get_mapping(poly): + def get_mapping(poly: VoronoiPolyhedron) -> VoronoiPolyhedron | None: """Helper function. Checks if a vornoi poly is a periodic image of @@ -600,7 +616,7 @@ def get_mapping(poly): if ref is None: vnodes.append(poly) - _logger.debug(f"{len(vnodes)} voronoi vertices in cell.") + _logger.debug("%s - voronoi vertices in cell.", len(vnodes)) # Eliminate all voronoi nodes which are closest to existing cations. if len(cations) > 0: @@ -616,7 +632,7 @@ def get_mapping(poly): cation_vnodes = [v for i, v in enumerate(vnodes) if i in indices] vnodes = [v for i, v in enumerate(vnodes) if i not in indices] - _logger.debug(f"{len(vnodes)} vertices in cell not with cation.") + _logger.debug("%s - vertices in cell not with cation.", len(vnodes)) self.coords = coords self.vnodes = vnodes self.cation_vnodes = cation_vnodes @@ -645,7 +661,7 @@ def labeled_sites( sm=self.sm, ) - def check_volume(self): + def check_volume(self) -> None: """Basic check for volume of all voronoi poly sum to unit cell volume. Note that this does not apply after poly combination. @@ -654,14 +670,17 @@ def check_volume(self): v.volume for v in self.cation_vnodes ) if abs(vol - self.structure.volume) > 1e-8: # pragma: no cover - raise ValueError( + msg = ( "Sum of voronoi volumes is not equal to original volume of " "structure! This may lead to inaccurate results. You need to " "tweak the tolerance and max_cell_range until you get a " "correct mapping." ) + raise ValueError( + msg, + ) - def get_structure_with_nodes(self): + def get_structure_with_nodes(self) -> Structure: """Get the modified structure with the voronoi nodes inserted. The species is set as a DummySpecies X. @@ -675,7 +694,14 @@ def get_structure_with_nodes(self): class VoronoiPolyhedron: """Convenience container for a voronoi point in PBC and its associated polyhedron.""" - def __init__(self, lattice, frac_coords, polyhedron_indices, all_coords, name=None): + def __init__( + self, + lattice: Lattice, + frac_coords: npt.ArrayLike, + polyhedron_indices: list | set, + all_coords: list, + name: str | int | None = None, + ) -> None: """Initialize a VoronoiPolyhedron. Args: @@ -717,16 +743,16 @@ def is_image(self, poly: VoronoiPolyhedron, tol: float) -> bool: return True @property - def coordination(self): + def coordination(self) -> int: """Coordination number.""" return len(self.polyhedron_indices) @property - def volume(self): + def volume(self) -> float: """Volume of the polyhedron.""" return calculate_vol(self.polyhedron_coords) - def __str__(self): + def __str__(self) -> str: """String representation.""" return f"Voronoi polyhedron {self.name}" @@ -769,7 +795,7 @@ def __init__( stol: float = 0.3, angle_tol: float = 5, min_dist: float = 0.9, - ): + ) -> None: """Initialize the ChargeInsertionAnalyzer.""" self.chgcar = chgcar self.working_ion = working_ion @@ -803,7 +829,9 @@ def local_minima(self) -> list[npt.ArrayLike]: return [s for s, _ in self.labeled_sites] def filter_and_group( - self, avg_radius: float = 0.4, max_avg_charge: float = 1.0 + self, + avg_radius: float = 0.4, + max_avg_charge: float = 1.0, ) -> list[tuple[float, list[list[float]]]]: """Filter and group the insertion sites by average charge. @@ -823,7 +851,9 @@ def filter_and_group( avg_chg_first_member = {} for lab, g in lab_groups.items(): avg_chg_first_member[lab] = get_avg_chg( - self.chgcar, fcoord=self.local_minima[g[0]], radius=avg_radius + self.chgcar, + fcoord=self.local_minima[g[0]], + radius=avg_radius, ) res = [] @@ -835,7 +865,7 @@ def filter_and_group( return res -def _get_ipr(spin, k_index, procar): +def _get_ipr(spin: int, k_index: int, procar: Procar) -> npt.NDArray: states = procar.data[spin][k_index, ...] flat_states = states.reshape(states.shape[0], -1) return 1 / np.sum(flat_states**2, axis=1) @@ -859,19 +889,20 @@ def get_ipr_in_window( Returns: dict[(int, int), npt.NDArray]: The IPR of the states in the band window keyed for each k-point and spin. """ - res = dict() - for spin in bandstructure.bands.keys(): + res = {} + for spin in bandstructure.bands: s_index = 0 if spin == Spin.up else 1 # last band that is fully below the fermi level last_occ_idx = bisect.bisect_left( - bandstructure.bands[spin].max(1), bandstructure.efermi + bandstructure.bands[spin].max(1), + bandstructure.efermi, ) lbound = max(last_occ_idx - band_window, 0) ubound = min(last_occ_idx + band_window, bandstructure.nb_bands) for k_idx, _ in enumerate(bandstructure.kpoints): ipr = _get_ipr(spin, k_idx, procar) res[(k_idx, s_index)] = np.stack( - (np.arange(lbound, ubound), ipr[lbound:ubound]) + (np.arange(lbound, ubound), ipr[lbound:ubound]), ).T return res @@ -904,7 +935,10 @@ def get_localized_states( def sort_positive_definite( - list_in: list, ref1: Any, ref2: Any, dist: Callable + list_in: list, + ref1: object, + ref2: object, + dist: Callable, ) -> tuple[tuple, tuple[float]]: """Sort a list where we can only compute a positive-definite distance. @@ -939,7 +973,7 @@ def sort_positive_definite( return sorted_list, distances -def calculate_vol(coords: npt.NDArray): +def calculate_vol(coords: npt.NDArray) -> float: """Calculate volume given a set of points in 3D space. Args: @@ -951,10 +985,9 @@ def calculate_vol(coords: npt.NDArray): return ConvexHull(coords).volume -@deprecated("Name changed") -def get_symmetry_labeled_structures(): +@deprecated("Name changed to get_labeled_inserted_structure.") +def get_symmetry_labeled_structures() -> None: """Deprecated.""" - pass def get_labeled_inserted_structure( @@ -1019,7 +1052,9 @@ class CorrectionResult(MSONable): metadata: dict[Any, Any] -def _group_docs_by_structure(docs: list, sm: StructureMatcher, get_structure: Callable): +def _group_docs_by_structure( + docs: list, sm: StructureMatcher, get_structure: Callable +) -> Generator[list, None, None]: """Group docs by structure. Args: @@ -1035,8 +1070,7 @@ def _group_docs_by_structure(docs: list, sm: StructureMatcher, get_structure: Ca comp=lambda x, y: sm.fit(get_structure(x), get_structure(y), symmetric=True), ) for ilab in set(labs): - sub_g = [docs[itr] for itr, jlab in enumerate(labs) if jlab == ilab] - yield [el for el in sub_g] + yield [docs[itr] for itr, jlab in enumerate(labs) if jlab == ilab] def group_docs( @@ -1044,7 +1078,7 @@ def group_docs( sm: StructureMatcher, get_structure: Callable, get_hash: Callable | None = None, -): +) -> Generator[tuple[str | None, list], None, None]: """Group docs by a simple hash followed by structure. Assuming that you have a basic representation of the defect, like `name`. @@ -1062,8 +1096,8 @@ def group_docs( Generator of (name, group) """ if get_hash is None: - for g in _group_docs_by_structure(docs, sm, get_structure): - yield None, g + for g_ in _group_docs_by_structure(docs, sm, get_structure): + yield None, g_ else: s_docs = sorted(docs, key=get_hash) for h, g in itertools.groupby(s_docs, key=get_hash): @@ -1107,7 +1141,7 @@ def get_plane_spacing(lattice: npt.NDArray) -> list[float]: spacing = [] for idir in range(ndim): idir_proj = [ - np.array(lattice[j]) * pproj[tuple(sorted([idir, j]))] # type: ignore + np.array(lattice[j]) * pproj[tuple(sorted([idir, j]))] # type: ignore[index] for j in range(ndim) if j != idir ] diff --git a/pyproject.toml b/pyproject.toml index 4c57bf14..6f6d02d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,57 +103,14 @@ src = ["pymatgen", "tests"] line-length = 88 indent-width = 4 - -# By default, ruff only uses all "E" (pycodestyle) and "F" (pyflakes) rules. -# Here we append to the defaults. -select = [ - # (flake8-builtins) detect shadowing of python builtin symbols by variables and arguments. - # Attributes are OK (which is why A003) is not included here. - "A001", - "A002", - # (useless expression): Expressions that aren't assigned to anything are typically bugs. - "B018", - # (pydocstyle) Docstring-related rules. A large subset of these are ignored by the - # "convention=google" setting, we set under tool.ruff.pydocstyle. - "D", - # (pycodestyle) pycodestyle rules - "E", - # (pyflakes) pyflakes rules - "F", - # (isort) detect improperly sorted imports - "I001", - # (pylint) use all pylint rules from categories "Convention", "Error", and "Warning" (ruff - # currently implements only a subset of pylint's rules) - "PLE", - "PLW", - # (no commented out code) keep commented out code blocks out of the codebase - # "ERA001", - # (ruff-specific) Enable all ruff-specific checks (i.e. not ports of - # functionality from an existing linter). - "RUF", - # (private member access) Flag access to `_`-prefixed symbols. By default the various special - # methods on `NamedTuple` are ignored (e.g. `_replace`). - "SLF001", - # (flake8-type-checking) Auto-sort imports into TYPE_CHECKING blocks depending on whether - # they are runtime or type-only imports. - "TCH", - # (banned-api) Flag use of banned APIs. See tool.ruff.flake8-tidy-imports.banned-api for details. - "TID251", - # (disallow print statements) keep debugging statements out of the codebase - "T20", - # (f-strings) use f-strings instead of .format() - "UP032", - # (invalid escape sequence) flag errant backslashes - "W605", -] - [tool.ruff.lint] -# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`). -select = ["E4", "E7", "E9", "F", "D", "I001", "TCH"] -ignore = ["E203", "E501", "F401"] +select = ["ALL"] +ignore = [ + "E203", "E501", "N", "PLR", "SLF", "ANN101", "ANN102", + "TRY", "COM812", "ISC001", "ERA001", "FBT", "C", "FIX", "TD", "ANN003"] # Allow fix for all enabled rules (when `--fix`) is provided. -fixable = ["ALL", "TCH"] +fixable = ["ALL"] unfixable = [] # Allow unused variables when underscore-prefixed. diff --git a/tests/test_ccd.py b/tests/test_ccd.py index d62bdea6..8bf2cd82 100644 --- a/tests/test_ccd.py +++ b/tests/test_ccd.py @@ -26,6 +26,8 @@ def hd0(v_ga): assert pytest.approx(hd0.distortions[1]) == 0.0 assert pytest.approx(hd0.omega_eV) == 0.03268045792725 assert hd0.defect_band == [(138, 0, 1), (138, 1, 1)] + assert hd0._get_ediff(output_order="bks").shape == (216, 2, 2) + assert hd0._get_ediff(output_order="skb").shape == (2, 2, 216) return hd0 @@ -42,6 +44,24 @@ def hd1(v_ga): assert pytest.approx(hd1.omega_eV) == 0.03341323356861477 return hd1 +def test_defect_band_raises(v_ga): + vaspruns = v_ga[(0, -1)]["vaspruns"] + procar = v_ga[(0, -1)]["procar"] + hd0 = HarmonicDefect.from_vaspruns( + vaspruns, + charge_state=0, + procar=procar, + store_bandstructure=True, + ) + # mis-matched defect band + hd0.defect_band = [(138, 0, 1), (139, 1, 1)] + with pytest.raises(ValueError) as e: + assert hd0.defect_band_index + + # mis-matched defect spin + hd0.defect_band = [(138, 0, 1), (138, 1, 0)] + with pytest.raises(ValueError) as e: + assert hd0.spin_index == 1 def test_HarmonicDefect(hd0, v_ga, test_dir): # test other basic reading functions for HarmonicDefect diff --git a/tests/test_core.py b/tests/test_core.py index 238067f9..a882392e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -26,6 +26,7 @@ def test_vacancy(gan_struct): assert vac.name == "v_Ga" assert vac == vac assert vac.element_changes == {Element("Ga"): -1} + assert vac.latex_name == r"v$_{\rm Ga}$" def test_substitution(gan_struct): @@ -48,6 +49,7 @@ def test_substitution(gan_struct): assert sub.latex_name == r"O$_{\rm N}$" assert sub == sub assert sub.element_changes == {Element("N"): -1, Element("O"): 1} + assert sub.latex_name == r"O$_{\rm N}$" # test supercell with locking sc_locked = sub.get_supercell_structure(relax_radius=5.0) @@ -122,6 +124,7 @@ def test_interstitial(gan_struct): assert inter.name == "N_i" assert str(inter) == "N intersitial site at [0.00,0.00,0.75]" assert inter.element_changes == {Element("N"): 1} + assert inter.latex_name == r"N$_{\rm i}$" # test target_frac_coords with get_supercell_structure finder = DefectSiteFinder() diff --git a/tests/test_supercells.py b/tests/test_supercells.py index f55b4a64..1ef169a7 100644 --- a/tests/test_supercells.py +++ b/tests/test_supercells.py @@ -7,6 +7,7 @@ get_sc_fromstruct, get_closest_sc_mat ) +import pytest def test_supercells(gan_struct): @@ -24,10 +25,13 @@ def test_supercells(gan_struct): def test_ase_supercells(gan_struct): - sc_mat = _ase_cubic(gan_struct, min_atoms=4, max_atoms=8) + sc_mat = _ase_cubic(gan_struct, min_atoms=4, max_atoms=8, min_length=1.0) sc = gan_struct * sc_mat assert 4 <= sc.num_sites <= 8 + # check raise + with pytest.raises(RuntimeError): + _ase_cubic(gan_struct, min_atoms=4, max_atoms=8, min_length=10) def test_closest_sc_mat(test_dir):