Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VASP IO copy() methods #3602

Merged
merged 15 commits into from
Feb 8, 2024
2 changes: 1 addition & 1 deletion pymatgen/core/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __format__(self, fmt_spec: str = "") -> str:

def copy(self):
"""Deep copy of self."""
return self.__class__(self.matrix.copy(), pbc=self.pbc)
return type(self)(self.matrix.copy(), pbc=self.pbc)

@property
def matrix(self) -> np.ndarray:
Expand Down
12 changes: 6 additions & 6 deletions pymatgen/core/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def copy(self):
Returns:
Copy of Spectrum object.
"""
return self.__class__(self.x, self.y, *self._args, **self._kwargs)
return type(self)(self.x, self.y, *self._args, **self._kwargs)

def __add__(self, other):
"""Add two Spectrum object together. Checks that x scales are the same.
Expand All @@ -150,7 +150,7 @@ def __add__(self, other):
"""
if not all(np.equal(self.x, other.x)):
raise ValueError("X axis values are not compatible!")
return self.__class__(self.x, self.y + other.y, *self._args, **self._kwargs)
return type(self)(self.x, self.y + other.y, *self._args, **self._kwargs)

def __sub__(self, other):
"""Subtract one Spectrum object from another. Checks that x scales are
Expand All @@ -165,7 +165,7 @@ def __sub__(self, other):
"""
if not all(np.equal(self.x, other.x)):
raise ValueError("X axis values are not compatible!")
return self.__class__(self.x, self.y - other.y, *self._args, **self._kwargs)
return type(self)(self.x, self.y - other.y, *self._args, **self._kwargs)

def __mul__(self, other):
"""Scale the Spectrum's y values.
Expand All @@ -176,7 +176,7 @@ def __mul__(self, other):
Returns:
Spectrum object with y values scaled
"""
return self.__class__(self.x, other * self.y, *self._args, **self._kwargs)
return type(self)(self.x, other * self.y, *self._args, **self._kwargs)

__rmul__ = __mul__

Expand All @@ -189,7 +189,7 @@ def __truediv__(self, other):
Returns:
Spectrum object with y values divided
"""
return self.__class__(self.x, self.y.__truediv__(other), *self._args, **self._kwargs)
return type(self)(self.x, self.y.__truediv__(other), *self._args, **self._kwargs)

def __floordiv__(self, other):
"""True division of y.
Expand All @@ -200,7 +200,7 @@ def __floordiv__(self, other):
Returns:
Spectrum object with y values divided
"""
return self.__class__(self.x, self.y.__floordiv__(other), *self._args, **self._kwargs)
return type(self)(self.x, self.y.__floordiv__(other), *self._args, **self._kwargs)

__div__ = __truediv__

Expand Down
18 changes: 8 additions & 10 deletions pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,7 +2083,7 @@ def get_reduced_structure(self, reduction_algo: Literal["niggli", "LLL"] = "nigg
raise ValueError(f"Invalid {reduction_algo=}")

if reduced_latt != self.lattice:
return self.__class__(
return type(self)(
reduced_latt,
self.species_and_occu,
self.cart_coords, # type: ignore
Expand Down Expand Up @@ -2124,7 +2124,7 @@ def copy(self, site_properties=None, sanitize=False, properties=None) -> Structu
if properties:
props.update(properties)
if not sanitize:
return self.__class__(
return type(self)(
self._lattice,
self.species_and_occu,
self.frac_coords,
Expand Down Expand Up @@ -2265,9 +2265,7 @@ def interpolate(
else:
lattice = self.lattice
fcoords = start_coords + x * vec
structs.append(
self.__class__(lattice, sp, fcoords, site_properties=self.site_properties, labels=self.labels)
)
structs.append(type(self)(lattice, sp, fcoords, site_properties=self.site_properties, labels=self.labels))
return structs

def get_miller_index_from_site_indexes(self, site_ids, round_dp=4, verbose=True):
Expand Down Expand Up @@ -2526,11 +2524,11 @@ def to_str(x) -> str:
data = []
props = self.site_properties
keys = sorted(props)
for i, site in enumerate(self):
row = [str(i), site.species_string]
for idx, site in enumerate(self):
row = [str(idx), site.species_string]
row.extend([to_str(j) for j in site.frac_coords])
for k in keys:
row.append(props[k][i])
for key in keys:
row.append(props[key][idx])
data.append(row)
outs.append(
tabulate(
Expand Down Expand Up @@ -3517,7 +3515,7 @@ def get_centered_molecule(self) -> IMolecule | Molecule:
"""
center = self.center_of_mass
new_coords = np.array(self.cart_coords) - center
return self.__class__(
return type(self)(
self.species_and_occu,
new_coords,
charge=self._charge,
Expand Down
22 changes: 11 additions & 11 deletions pymatgen/core/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def transform(self, symm_op):
Args:
symm_op (SymmOp): a symmetry operation to apply to the tensor
"""
return self.__class__(symm_op.transform_tensor(self))
return type(self)(symm_op.transform_tensor(self))

def rotate(self, matrix, tol: float = 1e-3):
"""Applies a rotation directly, and tests input matrix to ensure a valid
Expand Down Expand Up @@ -257,7 +257,7 @@ def round(self, decimals=0):
Returns (Tensor):
rounded tensor of same type
"""
return self.__class__(np.round(self, decimals=decimals))
return type(self)(np.round(self, decimals=decimals))

@property
def symmetrized(self):
Expand Down Expand Up @@ -628,7 +628,7 @@ def merge(old, new) -> None:
if not converged:
max_diff = np.max(np.abs(self - test_new))
warnings.warn(f"Warning, populated tensor is not converged with max diff of {max_diff}")
return self.__class__(test_new)
return type(self)(test_new)

def as_dict(self, voigt: bool = False) -> dict:
"""Serializes the tensor object.
Expand Down Expand Up @@ -689,7 +689,7 @@ def zeroed(self, tol: float = 1e-3):
Returns:
TensorCollection where small values are set to 0.
"""
return self.__class__([tensor.zeroed(tol) for tensor in self])
return type(self)([tensor.zeroed(tol) for tensor in self])

def transform(self, symm_op):
"""Transforms TensorCollection with a symmetry operation.
Expand All @@ -699,7 +699,7 @@ def transform(self, symm_op):
Returns:
TensorCollection.
"""
return self.__class__([tensor.transform(symm_op) for tensor in self])
return type(self)([tensor.transform(symm_op) for tensor in self])

def rotate(self, matrix, tol: float = 1e-3):
"""Rotates TensorCollection.
Expand All @@ -710,12 +710,12 @@ def rotate(self, matrix, tol: float = 1e-3):
Returns:
TensorCollection.
"""
return self.__class__([tensor.rotate(matrix, tol) for tensor in self])
return type(self)([tensor.rotate(matrix, tol) for tensor in self])

@property
def symmetrized(self):
"""TensorCollection where all tensors are symmetrized."""
return self.__class__([tensor.symmetrized for tensor in self])
return type(self)([tensor.symmetrized for tensor in self])

def is_symmetric(self, tol: float = 1e-5):
""":param tol: tolerance
Expand All @@ -734,7 +734,7 @@ def fit_to_structure(self, structure: Structure, symprec: float = 0.1):
Returns:
TensorCollection.
"""
return self.__class__([tensor.fit_to_structure(structure, symprec) for tensor in self])
return type(self)([tensor.fit_to_structure(structure, symprec) for tensor in self])

def is_fit_to_structure(self, structure: Structure, tol: float = 1e-2):
""":param structure: Structure
Expand Down Expand Up @@ -785,7 +785,7 @@ def convert_to_ieee(self, structure: Structure, initial_fit=True, refine_rotatio
Returns:
TensorCollection.
"""
return self.__class__([tensor.convert_to_ieee(structure, initial_fit, refine_rotation) for tensor in self])
return type(self)([tensor.convert_to_ieee(structure, initial_fit, refine_rotation) for tensor in self])

def round(self, *args, **kwargs):
"""Round all tensors.
Expand All @@ -796,12 +796,12 @@ def round(self, *args, **kwargs):
Returns:
TensorCollection.
"""
return self.__class__([tensor.round(*args, **kwargs) for tensor in self])
return type(self)([tensor.round(*args, **kwargs) for tensor in self])

@property
def voigt_symmetrized(self):
"""TensorCollection where all tensors are voigt symmetrized."""
return self.__class__([tensor.voigt_symmetrized for tensor in self])
return type(self)([tensor.voigt_symmetrized for tensor in self])

def as_dict(self, voigt=False):
""":param voigt: Whether to use Voigt form.
Expand Down
20 changes: 10 additions & 10 deletions pymatgen/core/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def __add__(self, other):
if other.unit != self.unit:
other = other.to(self.unit)

return self.__class__(np.array(self) + np.array(other), unit_type=self.unit_type, unit=self.unit)
return type(self)(np.array(self) + np.array(other), unit_type=self.unit_type, unit=self.unit)

def __sub__(self, other):
if hasattr(other, "unit_type"):
Expand All @@ -530,7 +530,7 @@ def __sub__(self, other):
if other.unit != self.unit:
other = other.to(self.unit)

return self.__class__(np.array(self) - np.array(other), unit_type=self.unit_type, unit=self.unit)
return type(self)(np.array(self) - np.array(other), unit_type=self.unit_type, unit=self.unit)

def __mul__(self, other):
# TODO Here we have the most important difference between FloatWithUnit and
Expand All @@ -544,31 +544,31 @@ def __mul__(self, other):
# bit misleading.
# Same protocol for __div__
if not hasattr(other, "unit_type"):
return self.__class__(
return type(self)(
np.array(self) * np.array(other),
unit_type=self._unit_type,
unit=self._unit,
)
# Cannot use super since it returns an instance of self.__class__
# while here we want a bare numpy array.
return self.__class__(np.array(self).__mul__(np.array(other)), unit=self.unit * other.unit)
return type(self)(np.array(self).__mul__(np.array(other)), unit=self.unit * other.unit)

def __rmul__(self, other):
if not hasattr(other, "unit_type"):
return self.__class__(
return type(self)(
np.array(self) * np.array(other),
unit_type=self._unit_type,
unit=self._unit,
)
return self.__class__(np.array(self) * np.array(other), unit=self.unit * other.unit)
return type(self)(np.array(self) * np.array(other), unit=self.unit * other.unit)

def __truediv__(self, other):
if not hasattr(other, "unit_type"):
return self.__class__(np.array(self) / np.array(other), unit_type=self._unit_type, unit=self._unit)
return self.__class__(np.array(self) / np.array(other), unit=self.unit / other.unit)
return type(self)(np.array(self) / np.array(other), unit_type=self._unit_type, unit=self._unit)
return type(self)(np.array(self) / np.array(other), unit=self.unit / other.unit)

def __neg__(self):
return self.__class__(-np.array(self), unit_type=self.unit_type, unit=self.unit)
return type(self)(-np.array(self), unit_type=self.unit_type, unit=self.unit)

def to(self, new_unit):
"""Conversion to a new_unit.
Expand All @@ -585,7 +585,7 @@ def to(self, new_unit):
>>> e.to("eV")
array([ 27.21138386, 29.93252225]) eV
"""
return self.__class__(
return type(self)(
np.array(self) * self.unit.get_conversion_factor(new_unit),
unit_type=self.unit_type,
unit=new_unit,
Expand Down
7 changes: 4 additions & 3 deletions pymatgen/io/abinit/abitimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,10 +716,11 @@ def get_dataframe(self, sort_key="wall_time", **kwargs):
def get_values(self, keys):
"""Return a list of values associated to a particular list of keys."""
if isinstance(keys, str):
return [s.__dict__[keys] for s in self.sections]
return [sec.__dict__[keys] for sec in self.sections]

values = []
for k in keys:
values.append([s.__dict__[k] for s in self.sections])
for key in keys:
values.append([sec.__dict__[key] for sec in self.sections])
return values

def names_and_values(self, key, minval=None, minfract=None, sorted=True):
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/io/abinit/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def to_str(self, post=None, with_structure=True, with_pseudos=True, exclude=None
keys = sorted(k for k, v in self.items() if k not in exclude and v is not None)

# Extract the items from the dict and add the geo variables at the end
items = [(k, self[k]) for k in keys]
items = [(key, self[key]) for key in keys]
if with_structure:
items.extend(list(aobj.structure_to_abivars(self.structure).items()))

Expand Down
12 changes: 6 additions & 6 deletions pymatgen/io/abinit/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,19 +218,19 @@ def read_keys(self, keys, dict_cls=AttrDict, path="/"):
Read a list of variables/dimensions from file. If a key is not present the corresponding
entry in the output dictionary is set to None.
"""
od = dict_cls()
for k in keys:
dct = dict_cls()
for key in keys:
try:
# Try to read a variable.
od[k] = self.read_value(k, path=path)
dct[key] = self.read_value(key, path=path)
except self.Error:
try:
# Try to read a dimension.
od[k] = self.read_dimvalue(k, path=path)
dct[key] = self.read_dimvalue(key, path=path)
except self.Error:
od[k] = None
dct[key] = None

return od
return dct


class EtsfReader(NetcdfReader):
Expand Down
16 changes: 8 additions & 8 deletions pymatgen/io/abinit/pseudos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,8 +1625,8 @@ def __getitem__(self, Z):
pseudos = []
for znum in iterator_from_slice(Z):
pseudos.extend(self._pseudos_with_z[znum])
return self.__class__(pseudos)
return self.__class__(self._pseudos_with_z[Z])
return type(self)(pseudos)
return type(self)(self._pseudos_with_z[Z])

def __len__(self) -> int:
return len(list(iter(self)))
Expand Down Expand Up @@ -1782,7 +1782,7 @@ def select_symbols(self, symbols, ret_list=False):

if ret_list:
return pseudos
return self.__class__(pseudos)
return type(self)(pseudos)

def get_pseudos_for_structure(self, structure: Structure):
"""
Expand Down Expand Up @@ -1839,11 +1839,11 @@ def sorted(self, attrname, reverse=False):
attrs.append((i, a))

# Sort attrs, and build new table with sorted pseudos.
return self.__class__([self[a[0]] for a in sorted(attrs, key=lambda t: t[1], reverse=reverse)])
return type(self)([self[a[0]] for a in sorted(attrs, key=lambda t: t[1], reverse=reverse)])

def sort_by_z(self):
"""Return a new PseudoTable with pseudos sorted by Z."""
return self.__class__(sorted(self, key=lambda p: p.Z))
return type(self)(sorted(self, key=lambda p: p.Z))

def select(self, condition) -> PseudoTable:
"""Select only those pseudopotentials for which condition is True.
Expand All @@ -1855,7 +1855,7 @@ def select(self, condition) -> PseudoTable:
Returns:
PseudoTable: New PseudoTable instance with pseudos for which condition is True.
"""
return self.__class__([p for p in self if condition(p)])
return type(self)([p for p in self if condition(p)])

def with_dojo_report(self):
"""Select pseudos containing the DOJO_REPORT section. Return new class:`PseudoTable` object."""
Expand All @@ -1868,9 +1868,9 @@ def select_rows(self, rows):
"""
if not isinstance(rows, (list, tuple)):
rows = [rows]
return self.__class__([p for p in self if p.element.row in rows])
return type(self)([p for p in self if p.element.row in rows])

def select_family(self, family):
"""Return PseudoTable with element belonging to the specified family, e.g. family="alkaline"."""
# e.g element.is_alkaline
return self.__class__([p for p in self if getattr(p.element, "is_" + family)])
return type(self)([p for p in self if getattr(p.element, "is_" + family)])
2 changes: 1 addition & 1 deletion pymatgen/io/lammps/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def structure(self) -> Structure:
if "nx" in atoms.columns:
atoms = atoms.drop(["nx", "ny", "nz"], axis=1)
atoms["molecule-ID"] = 1
ld_copy = self.__class__(self.box, masses, atoms)
ld_copy = type(self)(self.box, masses, atoms)
topologies = ld_copy.disassemble()[-1]
molecule = topologies[0].sites
coords = molecule.cart_coords - np.array(self.box.bounds)[:, 0]
Expand Down
Loading
Loading