diff --git a/pymatgen/core/lattice.py b/pymatgen/core/lattice.py index 7f2d3eabf67..f5bf9b7383c 100644 --- a/pymatgen/core/lattice.py +++ b/pymatgen/core/lattice.py @@ -125,7 +125,7 @@ def __format__(self, fmt_spec: str = "") -> str: def copy(self): """Deep copy of self.""" - return self.__class__(self.matrix.copy(), pbc=self.pbc) + return type(self)(self.matrix.copy(), pbc=self.pbc) @property def matrix(self) -> np.ndarray: diff --git a/pymatgen/core/spectrum.py b/pymatgen/core/spectrum.py index dc0efdcad63..b50ad021864 100644 --- a/pymatgen/core/spectrum.py +++ b/pymatgen/core/spectrum.py @@ -136,7 +136,7 @@ def copy(self): Returns: Copy of Spectrum object. """ - return self.__class__(self.x, self.y, *self._args, **self._kwargs) + return type(self)(self.x, self.y, *self._args, **self._kwargs) def __add__(self, other): """Add two Spectrum object together. Checks that x scales are the same. @@ -150,7 +150,7 @@ def __add__(self, other): """ if not all(np.equal(self.x, other.x)): raise ValueError("X axis values are not compatible!") - return self.__class__(self.x, self.y + other.y, *self._args, **self._kwargs) + return type(self)(self.x, self.y + other.y, *self._args, **self._kwargs) def __sub__(self, other): """Subtract one Spectrum object from another. Checks that x scales are @@ -165,7 +165,7 @@ def __sub__(self, other): """ if not all(np.equal(self.x, other.x)): raise ValueError("X axis values are not compatible!") - return self.__class__(self.x, self.y - other.y, *self._args, **self._kwargs) + return type(self)(self.x, self.y - other.y, *self._args, **self._kwargs) def __mul__(self, other): """Scale the Spectrum's y values. @@ -176,7 +176,7 @@ def __mul__(self, other): Returns: Spectrum object with y values scaled """ - return self.__class__(self.x, other * self.y, *self._args, **self._kwargs) + return type(self)(self.x, other * self.y, *self._args, **self._kwargs) __rmul__ = __mul__ @@ -189,7 +189,7 @@ def __truediv__(self, other): Returns: Spectrum object with y values divided """ - return self.__class__(self.x, self.y.__truediv__(other), *self._args, **self._kwargs) + return type(self)(self.x, self.y.__truediv__(other), *self._args, **self._kwargs) def __floordiv__(self, other): """True division of y. @@ -200,7 +200,7 @@ def __floordiv__(self, other): Returns: Spectrum object with y values divided """ - return self.__class__(self.x, self.y.__floordiv__(other), *self._args, **self._kwargs) + return type(self)(self.x, self.y.__floordiv__(other), *self._args, **self._kwargs) __div__ = __truediv__ diff --git a/pymatgen/core/structure.py b/pymatgen/core/structure.py index 0284ff343e7..bda6f2cac96 100644 --- a/pymatgen/core/structure.py +++ b/pymatgen/core/structure.py @@ -2083,7 +2083,7 @@ def get_reduced_structure(self, reduction_algo: Literal["niggli", "LLL"] = "nigg raise ValueError(f"Invalid {reduction_algo=}") if reduced_latt != self.lattice: - return self.__class__( + return type(self)( reduced_latt, self.species_and_occu, self.cart_coords, # type: ignore @@ -2124,7 +2124,7 @@ def copy(self, site_properties=None, sanitize=False, properties=None) -> Structu if properties: props.update(properties) if not sanitize: - return self.__class__( + return type(self)( self._lattice, self.species_and_occu, self.frac_coords, @@ -2265,9 +2265,7 @@ def interpolate( else: lattice = self.lattice fcoords = start_coords + x * vec - structs.append( - self.__class__(lattice, sp, fcoords, site_properties=self.site_properties, labels=self.labels) - ) + structs.append(type(self)(lattice, sp, fcoords, site_properties=self.site_properties, labels=self.labels)) return structs def get_miller_index_from_site_indexes(self, site_ids, round_dp=4, verbose=True): @@ -2526,11 +2524,11 @@ def to_str(x) -> str: data = [] props = self.site_properties keys = sorted(props) - for i, site in enumerate(self): - row = [str(i), site.species_string] + for idx, site in enumerate(self): + row = [str(idx), site.species_string] row.extend([to_str(j) for j in site.frac_coords]) - for k in keys: - row.append(props[k][i]) + for key in keys: + row.append(props[key][idx]) data.append(row) outs.append( tabulate( @@ -3517,7 +3515,7 @@ def get_centered_molecule(self) -> IMolecule | Molecule: """ center = self.center_of_mass new_coords = np.array(self.cart_coords) - center - return self.__class__( + return type(self)( self.species_and_occu, new_coords, charge=self._charge, diff --git a/pymatgen/core/tensors.py b/pymatgen/core/tensors.py index ec0a9569403..b1c5887bab8 100644 --- a/pymatgen/core/tensors.py +++ b/pymatgen/core/tensors.py @@ -110,7 +110,7 @@ def transform(self, symm_op): Args: symm_op (SymmOp): a symmetry operation to apply to the tensor """ - return self.__class__(symm_op.transform_tensor(self)) + return type(self)(symm_op.transform_tensor(self)) def rotate(self, matrix, tol: float = 1e-3): """Applies a rotation directly, and tests input matrix to ensure a valid @@ -257,7 +257,7 @@ def round(self, decimals=0): Returns (Tensor): rounded tensor of same type """ - return self.__class__(np.round(self, decimals=decimals)) + return type(self)(np.round(self, decimals=decimals)) @property def symmetrized(self): @@ -628,7 +628,7 @@ def merge(old, new) -> None: if not converged: max_diff = np.max(np.abs(self - test_new)) warnings.warn(f"Warning, populated tensor is not converged with max diff of {max_diff}") - return self.__class__(test_new) + return type(self)(test_new) def as_dict(self, voigt: bool = False) -> dict: """Serializes the tensor object. @@ -689,7 +689,7 @@ def zeroed(self, tol: float = 1e-3): Returns: TensorCollection where small values are set to 0. """ - return self.__class__([tensor.zeroed(tol) for tensor in self]) + return type(self)([tensor.zeroed(tol) for tensor in self]) def transform(self, symm_op): """Transforms TensorCollection with a symmetry operation. @@ -699,7 +699,7 @@ def transform(self, symm_op): Returns: TensorCollection. """ - return self.__class__([tensor.transform(symm_op) for tensor in self]) + return type(self)([tensor.transform(symm_op) for tensor in self]) def rotate(self, matrix, tol: float = 1e-3): """Rotates TensorCollection. @@ -710,12 +710,12 @@ def rotate(self, matrix, tol: float = 1e-3): Returns: TensorCollection. """ - return self.__class__([tensor.rotate(matrix, tol) for tensor in self]) + return type(self)([tensor.rotate(matrix, tol) for tensor in self]) @property def symmetrized(self): """TensorCollection where all tensors are symmetrized.""" - return self.__class__([tensor.symmetrized for tensor in self]) + return type(self)([tensor.symmetrized for tensor in self]) def is_symmetric(self, tol: float = 1e-5): """:param tol: tolerance @@ -734,7 +734,7 @@ def fit_to_structure(self, structure: Structure, symprec: float = 0.1): Returns: TensorCollection. """ - return self.__class__([tensor.fit_to_structure(structure, symprec) for tensor in self]) + return type(self)([tensor.fit_to_structure(structure, symprec) for tensor in self]) def is_fit_to_structure(self, structure: Structure, tol: float = 1e-2): """:param structure: Structure @@ -785,7 +785,7 @@ def convert_to_ieee(self, structure: Structure, initial_fit=True, refine_rotatio Returns: TensorCollection. """ - return self.__class__([tensor.convert_to_ieee(structure, initial_fit, refine_rotation) for tensor in self]) + return type(self)([tensor.convert_to_ieee(structure, initial_fit, refine_rotation) for tensor in self]) def round(self, *args, **kwargs): """Round all tensors. @@ -796,12 +796,12 @@ def round(self, *args, **kwargs): Returns: TensorCollection. """ - return self.__class__([tensor.round(*args, **kwargs) for tensor in self]) + return type(self)([tensor.round(*args, **kwargs) for tensor in self]) @property def voigt_symmetrized(self): """TensorCollection where all tensors are voigt symmetrized.""" - return self.__class__([tensor.voigt_symmetrized for tensor in self]) + return type(self)([tensor.voigt_symmetrized for tensor in self]) def as_dict(self, voigt=False): """:param voigt: Whether to use Voigt form. diff --git a/pymatgen/core/units.py b/pymatgen/core/units.py index 40fe7f41af6..406051c62a8 100644 --- a/pymatgen/core/units.py +++ b/pymatgen/core/units.py @@ -520,7 +520,7 @@ def __add__(self, other): if other.unit != self.unit: other = other.to(self.unit) - return self.__class__(np.array(self) + np.array(other), unit_type=self.unit_type, unit=self.unit) + return type(self)(np.array(self) + np.array(other), unit_type=self.unit_type, unit=self.unit) def __sub__(self, other): if hasattr(other, "unit_type"): @@ -530,7 +530,7 @@ def __sub__(self, other): if other.unit != self.unit: other = other.to(self.unit) - return self.__class__(np.array(self) - np.array(other), unit_type=self.unit_type, unit=self.unit) + return type(self)(np.array(self) - np.array(other), unit_type=self.unit_type, unit=self.unit) def __mul__(self, other): # TODO Here we have the most important difference between FloatWithUnit and @@ -544,31 +544,31 @@ def __mul__(self, other): # bit misleading. # Same protocol for __div__ if not hasattr(other, "unit_type"): - return self.__class__( + return type(self)( np.array(self) * np.array(other), unit_type=self._unit_type, unit=self._unit, ) # Cannot use super since it returns an instance of self.__class__ # while here we want a bare numpy array. - return self.__class__(np.array(self).__mul__(np.array(other)), unit=self.unit * other.unit) + return type(self)(np.array(self).__mul__(np.array(other)), unit=self.unit * other.unit) def __rmul__(self, other): if not hasattr(other, "unit_type"): - return self.__class__( + return type(self)( np.array(self) * np.array(other), unit_type=self._unit_type, unit=self._unit, ) - return self.__class__(np.array(self) * np.array(other), unit=self.unit * other.unit) + return type(self)(np.array(self) * np.array(other), unit=self.unit * other.unit) def __truediv__(self, other): if not hasattr(other, "unit_type"): - return self.__class__(np.array(self) / np.array(other), unit_type=self._unit_type, unit=self._unit) - return self.__class__(np.array(self) / np.array(other), unit=self.unit / other.unit) + return type(self)(np.array(self) / np.array(other), unit_type=self._unit_type, unit=self._unit) + return type(self)(np.array(self) / np.array(other), unit=self.unit / other.unit) def __neg__(self): - return self.__class__(-np.array(self), unit_type=self.unit_type, unit=self.unit) + return type(self)(-np.array(self), unit_type=self.unit_type, unit=self.unit) def to(self, new_unit): """Conversion to a new_unit. @@ -585,7 +585,7 @@ def to(self, new_unit): >>> e.to("eV") array([ 27.21138386, 29.93252225]) eV """ - return self.__class__( + return type(self)( np.array(self) * self.unit.get_conversion_factor(new_unit), unit_type=self.unit_type, unit=new_unit, diff --git a/pymatgen/io/abinit/abitimer.py b/pymatgen/io/abinit/abitimer.py index d4088bf8c05..22bf3462a3d 100644 --- a/pymatgen/io/abinit/abitimer.py +++ b/pymatgen/io/abinit/abitimer.py @@ -716,10 +716,11 @@ def get_dataframe(self, sort_key="wall_time", **kwargs): def get_values(self, keys): """Return a list of values associated to a particular list of keys.""" if isinstance(keys, str): - return [s.__dict__[keys] for s in self.sections] + return [sec.__dict__[keys] for sec in self.sections] + values = [] - for k in keys: - values.append([s.__dict__[k] for s in self.sections]) + for key in keys: + values.append([sec.__dict__[key] for sec in self.sections]) return values def names_and_values(self, key, minval=None, minfract=None, sorted=True): diff --git a/pymatgen/io/abinit/inputs.py b/pymatgen/io/abinit/inputs.py index d5dd488cf48..4036d9f9951 100644 --- a/pymatgen/io/abinit/inputs.py +++ b/pymatgen/io/abinit/inputs.py @@ -849,7 +849,7 @@ def to_str(self, post=None, with_structure=True, with_pseudos=True, exclude=None keys = sorted(k for k, v in self.items() if k not in exclude and v is not None) # Extract the items from the dict and add the geo variables at the end - items = [(k, self[k]) for k in keys] + items = [(key, self[key]) for key in keys] if with_structure: items.extend(list(aobj.structure_to_abivars(self.structure).items())) diff --git a/pymatgen/io/abinit/netcdf.py b/pymatgen/io/abinit/netcdf.py index 96ad0e9cd1d..a1c2e0685e5 100644 --- a/pymatgen/io/abinit/netcdf.py +++ b/pymatgen/io/abinit/netcdf.py @@ -218,19 +218,19 @@ def read_keys(self, keys, dict_cls=AttrDict, path="/"): Read a list of variables/dimensions from file. If a key is not present the corresponding entry in the output dictionary is set to None. """ - od = dict_cls() - for k in keys: + dct = dict_cls() + for key in keys: try: # Try to read a variable. - od[k] = self.read_value(k, path=path) + dct[key] = self.read_value(key, path=path) except self.Error: try: # Try to read a dimension. - od[k] = self.read_dimvalue(k, path=path) + dct[key] = self.read_dimvalue(key, path=path) except self.Error: - od[k] = None + dct[key] = None - return od + return dct class EtsfReader(NetcdfReader): diff --git a/pymatgen/io/abinit/pseudos.py b/pymatgen/io/abinit/pseudos.py index 290d880555e..8295808ce36 100644 --- a/pymatgen/io/abinit/pseudos.py +++ b/pymatgen/io/abinit/pseudos.py @@ -1625,8 +1625,8 @@ def __getitem__(self, Z): pseudos = [] for znum in iterator_from_slice(Z): pseudos.extend(self._pseudos_with_z[znum]) - return self.__class__(pseudos) - return self.__class__(self._pseudos_with_z[Z]) + return type(self)(pseudos) + return type(self)(self._pseudos_with_z[Z]) def __len__(self) -> int: return len(list(iter(self))) @@ -1782,7 +1782,7 @@ def select_symbols(self, symbols, ret_list=False): if ret_list: return pseudos - return self.__class__(pseudos) + return type(self)(pseudos) def get_pseudos_for_structure(self, structure: Structure): """ @@ -1839,11 +1839,11 @@ def sorted(self, attrname, reverse=False): attrs.append((i, a)) # Sort attrs, and build new table with sorted pseudos. - return self.__class__([self[a[0]] for a in sorted(attrs, key=lambda t: t[1], reverse=reverse)]) + return type(self)([self[a[0]] for a in sorted(attrs, key=lambda t: t[1], reverse=reverse)]) def sort_by_z(self): """Return a new PseudoTable with pseudos sorted by Z.""" - return self.__class__(sorted(self, key=lambda p: p.Z)) + return type(self)(sorted(self, key=lambda p: p.Z)) def select(self, condition) -> PseudoTable: """Select only those pseudopotentials for which condition is True. @@ -1855,7 +1855,7 @@ def select(self, condition) -> PseudoTable: Returns: PseudoTable: New PseudoTable instance with pseudos for which condition is True. """ - return self.__class__([p for p in self if condition(p)]) + return type(self)([p for p in self if condition(p)]) def with_dojo_report(self): """Select pseudos containing the DOJO_REPORT section. Return new class:`PseudoTable` object.""" @@ -1868,9 +1868,9 @@ def select_rows(self, rows): """ if not isinstance(rows, (list, tuple)): rows = [rows] - return self.__class__([p for p in self if p.element.row in rows]) + return type(self)([p for p in self if p.element.row in rows]) def select_family(self, family): """Return PseudoTable with element belonging to the specified family, e.g. family="alkaline".""" # e.g element.is_alkaline - return self.__class__([p for p in self if getattr(p.element, "is_" + family)]) + return type(self)([p for p in self if getattr(p.element, "is_" + family)]) diff --git a/pymatgen/io/lammps/data.py b/pymatgen/io/lammps/data.py index 1b7e084fc82..53ceb8301d8 100644 --- a/pymatgen/io/lammps/data.py +++ b/pymatgen/io/lammps/data.py @@ -302,7 +302,7 @@ def structure(self) -> Structure: if "nx" in atoms.columns: atoms = atoms.drop(["nx", "ny", "nz"], axis=1) atoms["molecule-ID"] = 1 - ld_copy = self.__class__(self.box, masses, atoms) + ld_copy = type(self)(self.box, masses, atoms) topologies = ld_copy.disassemble()[-1] molecule = topologies[0].sites coords = molecule.cart_coords - np.array(self.box.bounds)[:, 0] diff --git a/pymatgen/io/vasp/inputs.py b/pymatgen/io/vasp/inputs.py index 0f70dc20bcb..c81407b3d1f 100644 --- a/pymatgen/io/vasp/inputs.py +++ b/pymatgen/io/vasp/inputs.py @@ -58,8 +58,7 @@ class Poscar(MSONable): - """ - Object for representing the data in a POSCAR or CONTCAR file. + """Object for representing the data in a POSCAR or CONTCAR file. Attributes: structure: Associated Structure. @@ -682,8 +681,8 @@ def __init__(self, params: dict[str, Any] | None = None): params.get("LSORBIT") or params.get("LNONCOLLINEAR") ): val = [] - for i in range(len(params["MAGMOM"]) // 3): - val.append(params["MAGMOM"][i * 3 : (i + 1) * 3]) + for idx in range(len(params["MAGMOM"]) // 3): + val.append(params["MAGMOM"][idx * 3 : (idx + 1) * 3]) params["MAGMOM"] = val self.update(params) @@ -707,16 +706,20 @@ def as_dict(self) -> dict: return dct @classmethod - def from_dict(cls, d) -> Incar: + def from_dict(cls, dct: dict[str, Any]) -> Incar: """ - :param d: Dict representation. + Args: + dct (dict): Serialized Incar Returns: Incar """ - if d.get("MAGMOM") and isinstance(d["MAGMOM"][0], dict): - d["MAGMOM"] = [Magmom.from_dict(m) for m in d["MAGMOM"]] - return Incar({k: v for k, v in d.items() if k not in ("@module", "@class")}) + if dct.get("MAGMOM") and isinstance(dct["MAGMOM"][0], dict): + dct["MAGMOM"] = [Magmom.from_dict(m) for m in dct["MAGMOM"]] + return Incar({k: v for k, v in dct.items() if k not in ("@module", "@class")}) + + def copy(self): + return type(self)(self) def get_str(self, sort_keys: bool = False, pretty: bool = False) -> str: """ @@ -730,30 +733,28 @@ def get_str(self, sort_keys: bool = False, pretty: bool = False) -> str: pretty (bool): Set to True for pretty aligned output. Defaults to False. """ - keys = list(self) - if sort_keys: - keys = sorted(keys) + keys = sorted(self) if sort_keys else list(self) lines = [] - for k in keys: - if k == "MAGMOM" and isinstance(self[k], list): + for key in keys: + if key == "MAGMOM" and isinstance(self[key], list): value = [] - if isinstance(self[k][0], (list, Magmom)) and (self.get("LSORBIT") or self.get("LNONCOLLINEAR")): - value.append(" ".join(str(i) for j in self[k] for i in j)) + if isinstance(self[key][0], (list, Magmom)) and (self.get("LSORBIT") or self.get("LNONCOLLINEAR")): + value.append(" ".join(str(i) for j in self[key] for i in j)) elif self.get("LSORBIT") or self.get("LNONCOLLINEAR"): - for m, g in itertools.groupby(self[k]): + for m, g in itertools.groupby(self[key]): value.append(f"3*{len(tuple(g))}*{m}") else: # float() to ensure backwards compatibility between # float magmoms and Magmom objects - for m, g in itertools.groupby(self[k], key=float): + for m, g in itertools.groupby(self[key], key=float): value.append(f"{len(tuple(g))}*{m}") - lines.append([k, " ".join(value)]) - elif isinstance(self[k], list): - lines.append([k, " ".join(map(str, self[k]))]) + lines.append([key, " ".join(value)]) + elif isinstance(self[key], list): + lines.append([key, " ".join(map(str, self[key]))]) else: - lines.append([k, self[k]]) + lines.append([key, self[key]]) if pretty: return str(tabulate([[line[0], "=", line[1]] for line in lines], tablefmt="plain")) @@ -1327,6 +1328,14 @@ def automatic_linemode(divisions, ibz): num_kpts=int(divisions), ) + def copy(self): + return self.from_dict(self.as_dict()) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Kpoints): + return NotImplemented + return self.as_dict() == other.as_dict() + @classmethod def from_file(cls, filename): """ @@ -1395,8 +1404,8 @@ def from_str(cls, string): kpts = [] labels = [] patt = re.compile(r"([e0-9.\-]+)\s+([e0-9.\-]+)\s+([e0-9.\-]+)\s*!*\s*(.*)") - for i in range(4, len(lines)): - line = lines[i] + for idx in range(4, len(lines)): + line = lines[idx] m = patt.match(line) if m: kpts.append([float(m.group(1)), float(m.group(2)), float(m.group(3))]) @@ -1419,8 +1428,8 @@ def from_str(cls, string): tet_weight = 0 tet_connections = None - for i in range(3, 3 + num_kpts): - tokens = lines[i].split() + for idx in range(3, 3 + num_kpts): + tokens = lines[idx].split() kpts.append([float(j) for j in tokens[0:3]]) kpts_weights.append(float(tokens[3])) if len(tokens) > 4: @@ -1434,8 +1443,8 @@ def from_str(cls, string): tet_number = int(tokens[0]) tet_weight = float(tokens[1]) tet_connections = [] - for i in range(5 + num_kpts, 5 + num_kpts + tet_number): - tokens = lines[i].split() + for idx in range(5 + num_kpts, 5 + num_kpts + tet_number): + tokens = lines[idx].split() tet_connections.append((int(tokens[0]), [int(tokens[j]) for j in range(1, 5)])) except IndexError: pass @@ -1761,7 +1770,7 @@ def __init__(self, data: str, symbol: str | None = None) -> None: ) def __str__(self) -> str: - return self.data + "\n" + return f"{self.data}\n" @property def electron_configuration(self) -> list[tuple[int, str, int]] | None: @@ -1780,8 +1789,7 @@ def electron_configuration(self) -> list[tuple[int, str, int]] | None: return config def write_file(self, filename: str) -> None: - """ - Write PotcarSingle to a file. + """Write PotcarSingle to a file. Args: filename (str): Filename to write to. @@ -1789,10 +1797,22 @@ def write_file(self, filename: str) -> None: with zopen(filename, mode="wt") as file: file.write(str(self)) + def __eq__(self, other: object) -> bool: + if not isinstance(other, PotcarSingle): + return NotImplemented + return self.data == other.data and self.keywords == other.keywords + + def copy(self) -> PotcarSingle: + """Returns a copy of the PotcarSingle. + + Returns: + PotcarSingle + """ + return PotcarSingle(self.data, symbol=self.symbol) + @classmethod def from_file(cls, filename: str) -> PotcarSingle: - """ - Reads PotcarSingle from file. + """Reads PotcarSingle from file. :param filename: Filename. @@ -1813,8 +1833,7 @@ def from_file(cls, filename: str) -> PotcarSingle: @classmethod def from_symbol_and_functional(cls, symbol: str, functional: str | None = None): - """ - Makes a PotcarSingle from a symbol and functional. + """Makes a PotcarSingle from a symbol and functional. Args: symbol (str): Symbol, e.g., Li_sv @@ -2335,7 +2354,6 @@ def __getattr__(self, attr: str) -> Any: def __repr__(self) -> str: cls_name = type(self).__name__ symbol, functional = self.symbol, self.functional - TITEL, VRHFIN = self.keywords["TITEL"], self.keywords["VRHFIN"] TITEL, VRHFIN, n_valence_elec = (self.keywords.get(key) for key in ("TITEL", "VRHFIN", "ZVAL")) return f"{cls_name}({symbol=}, {functional=}, {TITEL=}, {VRHFIN=}, {n_valence_elec=:.0f})" @@ -2416,7 +2434,7 @@ def __init__( symbols: Sequence[str] | None = None, functional: str | None = None, sym_potcar_map: dict[str, str] | None = None, - ): + ) -> None: """ Args: symbols (list[str]): Element symbols for POTCAR. This should correspond @@ -2439,8 +2457,8 @@ def __init__( if symbols is not None: self.set_symbols(symbols, functional, sym_potcar_map) - def __iter__(self) -> Iterator[PotcarSingle]: # __iter__ only needed to supply type hint - # so for psingle in Potcar() is correctly inferred as PotcarSingle + def __iter__(self) -> Iterator[PotcarSingle]: # boilerplate code. only here to supply + # type hint so `for psingle in Potcar()` is correctly inferred as PotcarSingle return super().__iter__() def as_dict(self): @@ -2541,7 +2559,15 @@ def set_symbols( class VaspInput(dict, MSONable): """Class to contain a set of vasp input objects corresponding to a run.""" - def __init__(self, incar, kpoints, poscar, potcar, optional_files=None, **kwargs): + def __init__( + self, + incar: Incar, + kpoints: Kpoints | None, + poscar: Poscar, + potcar: Potcar | None, + optional_files: dict[PathLike, object] | None = None, + **kwargs, + ) -> None: """ Initializes a VaspInput object with the given input files. @@ -2562,19 +2588,19 @@ def __init__(self, incar, kpoints, poscar, potcar, optional_files=None, **kwargs def __str__(self): output = [] - for k, v in self.items(): - output.extend((k, str(v), "")) + for key, val in self.items(): + output.extend((key, str(val), "")) return "\n".join(output) def as_dict(self): """MSONable dict.""" - dct = {k: v.as_dict() for k, v in self.items()} + dct = {key: val.as_dict() for key, val in self.items()} dct["@module"] = type(self).__module__ dct["@class"] = type(self).__name__ return dct @classmethod - def from_dict(cls, d): + def from_dict(cls, dct): """ :param d: Dict representation. @@ -2582,13 +2608,13 @@ def from_dict(cls, d): VaspInput """ dec = MontyDecoder() - sub_d = {"optional_files": {}} - for k, v in d.items(): - if k in ["INCAR", "POSCAR", "POTCAR", "KPOINTS"]: - sub_d[k.lower()] = dec.process_decoded(v) - elif k not in ["@module", "@class"]: - sub_d["optional_files"][k] = dec.process_decoded(v) - return cls(**sub_d) + sub_dct = {"optional_files": {}} + for key, val in dct.items(): + if key in ["INCAR", "POSCAR", "POTCAR", "KPOINTS"]: + sub_dct[key.lower()] = dec.process_decoded(val) + elif key not in ["@module", "@class"]: + sub_dct["optional_files"][key] = dec.process_decoded(val) + return cls(**sub_dct) def write_input(self, output_dir=".", make_dir_if_not_present=True): """ @@ -2620,7 +2646,7 @@ def from_directory(cls, input_dir, optional_files=None): dict of {filename: Object type}. Object type must have a static method from_file. """ - sub_d = {} + sub_dct = {} for fname, ftype in [ ("INCAR", Incar), ("KPOINTS", Kpoints), @@ -2629,15 +2655,21 @@ def from_directory(cls, input_dir, optional_files=None): ]: try: full_zpath = zpath(os.path.join(input_dir, fname)) - sub_d[fname.lower()] = ftype.from_file(full_zpath) + sub_dct[fname.lower()] = ftype.from_file(full_zpath) except FileNotFoundError: # handle the case where there is no KPOINTS file - sub_d[fname.lower()] = None + sub_dct[fname.lower()] = None - sub_d["optional_files"] = {} + sub_dct["optional_files"] = {} if optional_files is not None: for fname, ftype in optional_files.items(): - sub_d["optional_files"][fname] = ftype.from_file(os.path.join(input_dir, fname)) - return cls(**sub_d) + sub_dct["optional_files"][fname] = ftype.from_file(os.path.join(input_dir, fname)) + return cls(**sub_dct) + + def copy(self, deep: bool = True): + """Deep copy of VaspInput.""" + if deep: + return self.from_dict(self.as_dict()) + return type(self)(**{key.lower(): val for key, val in self.items()}) def run_vasp( self, diff --git a/pymatgen/symmetry/structure.py b/pymatgen/symmetry/structure.py index a7aa4913c63..932daace75d 100644 --- a/pymatgen/symmetry/structure.py +++ b/pymatgen/symmetry/structure.py @@ -96,16 +96,13 @@ def __repr__(self) -> str: return str(self) def __str__(self) -> str: - def to_str(x): - return f"{x:>10.6f}" - outs = [ "SymmetrizedStructure", f"Full Formula ({self.composition.formula})", f"Reduced Formula: {self.composition.reduced_formula}", f"Spacegroup: {self.spacegroup.int_symbol} ({self.spacegroup.int_number})", - f"abc : {' '.join(to_str(val) for val in self.lattice.abc)}", - f"angles: {' '.join(to_str(val) for val in self.lattice.angles)}", + f"abc : {' '.join(f'{val:>10.6f}' for val in self.lattice.abc)}", + f"angles: {' '.join(f'{val:>10.6f}' for val in self.lattice.angles)}", ] if self._charge: @@ -117,10 +114,10 @@ def to_str(x): for idx, sites in enumerate(self.equivalent_sites): site = sites[0] row = [str(idx), site.species_string] - row.extend([to_str(j) for j in site.frac_coords]) + row.extend([f"{j:>10.6f}" for j in site.frac_coords]) row.append(self.wyckoff_symbols[idx]) - for k in keys: - row.append(props[k][idx]) + for key in keys: + row.append(props[key][idx]) data.append(row) outs.append(tabulate(data, headers=["#", "SP", "a", "b", "c", "Wyckoff", *keys])) return "\n".join(outs) diff --git a/tests/io/vasp/test_inputs.py b/tests/io/vasp/test_inputs.py index f385a439404..5ba7f8d67af 100644 --- a/tests/io/vasp/test_inputs.py +++ b/tests/io/vasp/test_inputs.py @@ -59,8 +59,7 @@ def _mock_complete_potcar_summary_stats(monkeypatch: MonkeyPatch) -> None: class TestPoscar(PymatgenTest): def test_init(self): - filepath = f"{TEST_FILES_DIR}/POSCAR" - comp = Structure.from_file(filepath).composition + comp = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR").composition assert comp == Composition("Fe4P4O16") # VASP 4 type with symbols at the end. @@ -458,7 +457,7 @@ def test_vasp_6_4_2_format(self): with open(f"{TEST_FILES_DIR}/POSCAR.LiFePO4") as file: for idx, line in enumerate(file): if idx == 5: - line = " ".join([x + "/" for x in line.split()]) + "\n" + line = " ".join(f"{x}/" for x in line.split()) + "\n" poscar_str += line poscar = Poscar.from_str(poscar_str) assert poscar.structure.formula == "Li4 Fe4 P4 O16" @@ -466,8 +465,7 @@ def test_vasp_6_4_2_format(self): class TestIncar(PymatgenTest): def setUp(self): - file_name = f"{TEST_FILES_DIR}/INCAR" - self.incar = Incar.from_file(file_name) + self.incar = Incar.from_file(f"{TEST_FILES_DIR}/INCAR") def test_init(self): incar = self.incar @@ -476,6 +474,15 @@ def test_init(self): assert float(incar["EDIFF"]) == 1e-4, "Wrong EDIFF" assert isinstance(incar["LORBIT"], int) + def test_copy(self): + incar2 = self.incar.copy() + assert isinstance(incar2, Incar), f"Expected Incar, got {type(incar2)}" + assert incar2 == self.incar + # modify incar2 and check that incar1 is not modified + incar2["LDAU"] = "F" + assert incar2["LDAU"] is False + assert self.incar.get("LDAU") is None + def test_diff(self): filepath1 = f"{TEST_FILES_DIR}/INCAR" incar1 = Incar.from_file(filepath1) @@ -910,24 +917,42 @@ def test_as_dict_from_dict(self): def test_kpt_bands_as_dict_from_dict(self): file_name = f"{TEST_FILES_DIR}/KPOINTS.band" - k = Kpoints.from_file(file_name) - dct = k.as_dict() + kpts = Kpoints.from_file(file_name) + dct = kpts.as_dict() json.dumps(dct) # This doesn't work k2 = Kpoints.from_dict(dct) - assert k.kpts == k2.kpts - assert k.style == k2.style - assert k.kpts_shift == k2.kpts_shift - assert k.num_kpts == k2.num_kpts + assert kpts.kpts == k2.kpts + assert kpts.style == k2.style + assert kpts.kpts_shift == k2.kpts_shift + assert kpts.num_kpts == k2.num_kpts def test_pickle(self): - k = Kpoints.gamma_automatic() - pickle.dumps(k) + kpts = Kpoints.gamma_automatic() + pickle.dumps(kpts) + + def test_eq(self): + auto_g_kpts = Kpoints.gamma_automatic() + assert auto_g_kpts == auto_g_kpts + assert auto_g_kpts == Kpoints.gamma_automatic() + file_kpts = Kpoints.from_file(f"{TEST_FILES_DIR}/KPOINTS") + assert file_kpts == Kpoints.from_file(f"{TEST_FILES_DIR}/KPOINTS") + assert auto_g_kpts != file_kpts + auto_m_kpts = Kpoints.monkhorst_automatic([2, 2, 2], [0, 0, 0]) + assert auto_m_kpts == Kpoints.monkhorst_automatic([2, 2, 2], [0, 0, 0]) + assert auto_g_kpts != auto_m_kpts + + def test_copy(self): + kpts = Kpoints.gamma_automatic() + kpt_copy = kpts.copy() + assert kpts == kpt_copy + kpt_copy.style = Kpoints.supported_modes.Monkhorst + assert kpts != kpt_copy def test_automatic_kpoint(self): # struct = PymatgenTest.get_structure("Li2O") - p = Poscar.from_str( + poscar = Poscar.from_str( """Al1 1.0 2.473329 0.000000 1.427977 @@ -938,7 +963,7 @@ def test_automatic_kpoint(self): direct 0.000000 0.000000 0.000000 Al""" ) - kpoints = Kpoints.automatic_density(p.structure, 1000) + kpoints = Kpoints.automatic_density(poscar.structure, 1000) assert_allclose(kpoints.kpts[0], [10, 10, 10]) def test_automatic_density_by_lengths(self): @@ -1177,6 +1202,17 @@ def test_sha256_file_hash(self): == "7bcf5ad80200e5d74ba63b45d87825b31e6cae2bcd03cebda2f1cbec9870c1cf" ) + def test_eq(self): + assert self.psingle_Mn_pv == self.psingle_Mn_pv + assert self.psingle_Fe == self.psingle_Fe + assert self.psingle_Mn_pv != self.psingle_Fe + assert self.psingle_Mn_pv != self.psingle_Fe_54 + + def test_copy(self): + psingle = self.psingle_Mn_pv.copy() + assert psingle == self.psingle_Mn_pv + assert psingle is not self.psingle_Mn_pv + class TestPotcar(PymatgenTest): def setUp(self): @@ -1273,6 +1309,22 @@ def test_write(self): assert {*os.listdir(tmp_dir)} == {"INCAR", "KPOINTS", "POSCAR", "POTCAR"} + def test_copy(self): + vasp_input2 = self.vasp_input.copy(deep=True) + assert isinstance(vasp_input2, VaspInput) + # make copy and original serialize to the same dict + assert vasp_input2.as_dict() == self.vasp_input.as_dict() + # modify the copy and make sure the original is not modified + vasp_input2["INCAR"]["NSW"] = 100 + assert vasp_input2["INCAR"]["NSW"] == 100 + assert self.vasp_input["INCAR"]["NSW"] == 99 + + # make a shallow copy and make sure the original is modified + vasp_input3 = self.vasp_input.copy(deep=False) + vasp_input3["INCAR"]["NSW"] = 100 + assert vasp_input3["INCAR"]["NSW"] == 100 + assert self.vasp_input["INCAR"]["NSW"] == 100 + def test_run_vasp(self): self.vasp_input.run_vasp(".", vasp_cmd=["cat", "INCAR"]) with open("vasp.out") as file: