diff --git a/pymatgen/core/structure.py b/pymatgen/core/structure.py index 3604a7ae29c..500a689aaf5 100644 --- a/pymatgen/core/structure.py +++ b/pymatgen/core/structure.py @@ -232,6 +232,7 @@ def get_distance(self, i: int, j: int) -> float: Returns: Distance between sites at index i and index j. """ + raise NotImplementedError @property def distance_matrix(self) -> np.ndarray: @@ -426,20 +427,20 @@ def get_dihedral(self, i: int, j: int, k: int, l: int) -> float: # noqa: E741 """Returns dihedral angle specified by four sites. Args: - i: 1st site index - j: 2nd site index - k: 3rd site index - l: 4th site index + i (int): 1st site index + j (int): 2nd site index + k (int): 3rd site index + l (int): 4th site index Returns: Dihedral angle in degrees. """ - v1 = self[k].coords - self[l].coords - v2 = self[j].coords - self[k].coords - v3 = self[i].coords - self[j].coords - v23 = np.cross(v2, v3) - v12 = np.cross(v1, v2) - return math.degrees(math.atan2(np.linalg.norm(v2) * np.dot(v1, v23), np.dot(v12, v23))) + vec1 = self[k].coords - self[l].coords + vec2 = self[j].coords - self[k].coords + vec3 = self[i].coords - self[j].coords + vec23 = np.cross(vec2, vec3) + vec12 = np.cross(vec1, vec2) + return math.degrees(math.atan2(np.linalg.norm(vec2) * np.dot(vec1, vec23), np.dot(vec12, vec23))) def is_valid(self, tol: float = DISTANCE_TOLERANCE) -> bool: """True if SiteCollection does not contain atoms that are too close @@ -481,7 +482,7 @@ def from_file(cls, filename: str) -> None: """Reads in SiteCollection from a filename.""" raise NotImplementedError - def add_site_property(self, property_name: str, values: Sequence | np.ndarray) -> None: + def add_site_property(self, property_name: str, values: Sequence | np.ndarray) -> SiteCollection: """Adds a property to a site. Note: This is the preferred method for adding magnetic moments, selective dynamics, and related site-specific properties to a structure/molecule object. @@ -497,21 +498,31 @@ def add_site_property(self, property_name: str, values: Sequence | np.ndarray) - Raises: ValueError: if len(values) != number of sites. + + Returns: + SiteCollection: self with site property added. """ if len(values) != len(self): raise ValueError(f"{len(values)=} must equal sites in structure={len(self)}") for site, val in zip(self, values): site.properties[property_name] = val - def remove_site_property(self, property_name: str) -> None: + return self + + def remove_site_property(self, property_name: str) -> SiteCollection: """Removes a property to a site. Args: property_name (str): The name of the property to remove. + + Returns: + SiteCollection: self with property removed. """ for site in self: del site.properties[property_name] + return self + def replace_species( self, species_mapping: dict[SpeciesLike, SpeciesLike | dict[SpeciesLike, float]], in_place: bool = True ) -> SiteCollection: @@ -524,6 +535,9 @@ def replace_species( {Element('Si): {Element('Ge'): 0.75, Element('C'): 0.25} } will have .375 Ge and .125 C. in_place (bool): Whether to perform the substitution in place or modify a copy. Defaults to True. + + Returns: + SiteCollection: self or new SiteCollection (depending on in_place) with species replaced. """ site_coll = self if in_place else self.copy() sp_mapping = {get_el_sp(k): v for k, v in species_mapping.items()} @@ -548,7 +562,7 @@ def replace_species( return site_coll - def add_oxidation_state_by_element(self, oxidation_states: dict[str, float]) -> None: + def add_oxidation_state_by_element(self, oxidation_states: dict[str, float]) -> SiteCollection: """Add oxidation states. Args: @@ -557,6 +571,9 @@ def add_oxidation_state_by_element(self, oxidation_states: dict[str, float]) -> Raises: ValueError if oxidation states are not specified for all elements. + + Returns: + SiteCollection: self with oxidation states. """ missing = {el.symbol for el in self.composition} - {*oxidation_states} if missing: @@ -567,12 +584,20 @@ def add_oxidation_state_by_element(self, oxidation_states: dict[str, float]) -> new_sp[Species(el.symbol, oxidation_states[el.symbol])] = occu site.species = Composition(new_sp) - def add_oxidation_state_by_site(self, oxidation_states: list[float]) -> None: + return self + + def add_oxidation_state_by_site(self, oxidation_states: list[float]) -> SiteCollection: """Add oxidation states to a structure by site. Args: oxidation_states (list[float]): List of oxidation states. E.g. [1, 1, 1, 1, 2, 2, 2, 2, 5, 5, 5, 5, -2, -2, -2, -2] + + Raises: + ValueError if oxidation states are not specified for all sites. + + Returns: + SiteCollection: self with oxidation states. """ if len(oxidation_states) != len(self): raise ValueError( @@ -586,6 +611,8 @@ def add_oxidation_state_by_site(self, oxidation_states: list[float]) -> None: new_sp[Species(sym, ox)] = occu site.species = Composition(new_sp) + return self + def remove_oxidation_states(self) -> SiteCollection: """Removes oxidation states from a structure.""" for site in self: @@ -4115,13 +4142,12 @@ def apply_strain(self, strain: ArrayLike, inplace: bool = True) -> Structure: struct.lattice = new_lattice return struct - def sort(self, key: Callable | None = None, reverse: bool = False) -> None: + def sort(self, key: Callable | None = None, reverse: bool = False) -> Structure: """Sort a structure in place. The parameters have the same meaning as in - list.sort. By default, sites are sorted by the electronegativity of + list.sort(). By default, sites are sorted by the electronegativity of the species. The difference between this method and get_sorted_structure (which also works in IStructure) is that the - latter returns a new Structure, while this just sorts the Structure - in place. + latter returns a new Structure, while this modifies the original. Args: key: Specifies a function of one argument that is used to extract @@ -4129,14 +4155,18 @@ def sort(self, key: Callable | None = None, reverse: bool = False) -> None: default value is None (compare the elements directly). reverse (bool): If set to True, then the list elements are sorted as if each comparison were reversed. + + Returns: + Structure: Sorted structure. """ self._sites.sort(key=key, reverse=reverse) + return self def translate_sites( self, indices: int | Sequence[int], vector: ArrayLike, frac_coords: bool = True, to_unit_cell: bool = True - ) -> None: + ) -> Structure: """Translate specific sites by some vector, keeping the sites within the - unit cell. + unit cell. Modifies the structure in place. Args: indices: Integer or List of site indices on which to perform the @@ -4146,6 +4176,9 @@ def translate_sites( Cartesian coordinates. to_unit_cell (bool): Whether new sites are transformed to unit cell + + Returns: + Structure: self with translated sites. """ if not isinstance(indices, collections.abc.Iterable): indices = [indices] @@ -4160,6 +4193,8 @@ def translate_sites( f_coords = [np.mod(f, 1) if p else f for p, f in zip(self.lattice.pbc, f_coords)] self[idx].frac_coords = f_coords + return self + def rotate_sites( self, indices: list[int] | None = None, @@ -4167,8 +4202,9 @@ def rotate_sites( axis: ArrayLike | None = None, anchor: ArrayLike | None = None, to_unit_cell: bool = True, - ) -> None: - """Rotate specific sites by some angle around vector at anchor. + ) -> Structure: + """Rotate specific sites by some angle around vector at anchor. Modifies + the structure in place. Args: indices (list): List of site indices on which to perform the @@ -4176,8 +4212,10 @@ def rotate_sites( theta (float): Angle in radians axis (3x1 array): Rotation axis vector. anchor (3x1 array): Point of rotation. - to_unit_cell (bool): Whether new sites are transformed to unit - cell + to_unit_cell (bool): Whether new sites are transformed to unit cell + + Returns: + Structure: self with rotated sites. """ if indices is None: indices = list(range(len(self))) @@ -4209,9 +4247,11 @@ def rotate_sites( ) self[idx] = new_site - def perturb(self, distance: float, min_distance: float | None = None) -> None: + return self + + def perturb(self, distance: float, min_distance: float | None = None) -> Structure: """Performs a random perturbation of the sites in a structure to break - symmetries. + symmetries. Modifies the structure in place. Args: distance (float): Distance in angstroms by which to perturb each @@ -4220,6 +4260,9 @@ def perturb(self, distance: float, min_distance: float | None = None) -> None: be equal amplitude. If int or float, perturb each site a distance drawn from the uniform distribution between 'min_distance' and 'distance'. + + Returns: + Structure: self with perturbed sites. """ def get_rand_vec(): @@ -4234,6 +4277,8 @@ def get_rand_vec(): for idx in range(len(self._sites)): self.translate_sites([idx], get_rand_vec(), frac_coords=False) + return self + def make_supercell(self, scaling_matrix: ArrayLike, to_unit_cell: bool = True, in_place: bool = True) -> Structure: """Create a supercell. @@ -4270,16 +4315,21 @@ def make_supercell(self, scaling_matrix: ArrayLike, to_unit_cell: bool = True, i return struct - def scale_lattice(self, volume: float) -> None: + def scale_lattice(self, volume: float) -> Structure: """Performs a scaling of the lattice vectors so that length proportions and angles are preserved. Args: volume (float): New volume of the unit cell in A^3. + + Returns: + Structure: self with scaled lattice. """ self.lattice = self._lattice.scale(volume) - def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum") -> None: + return self + + def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum") -> Structure: """Merges sites (adding occupancies) within tol of each other. Removes site properties. @@ -4289,6 +4339,9 @@ def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average deleted. "sum" means the occupancies are summed for the sites. "average" means that the site is deleted but the properties are averaged Only first letter is considered. + + Returns: + Structure: self with merged sites. """ dist_mat = self.distance_matrix np.fill_diagonal(dist_mat, 0) @@ -4318,14 +4371,19 @@ def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average sites.append(PeriodicSite(species, coords, self.lattice, properties=props)) self._sites = sites + return self - def set_charge(self, new_charge: float = 0.0) -> None: + def set_charge(self, new_charge: float = 0.0) -> Structure: """Sets the overall structure charge. Args: new_charge (float): new charge to set + + Returns: + Structure: self with new charge set. """ self._charge = new_charge + return self def relax( self, @@ -4396,7 +4454,7 @@ def from_prototype(cls, prototype: str, species: Sequence, **kwargs) -> Structur specified. For example, if it is a cubic prototype, only a needs to be specified. Returns: - Structure + Structure: with given prototype and species. """ prototype = prototype.lower() try: @@ -4568,7 +4626,7 @@ def append( # type: ignore properties=properties, ) - def set_charge_and_spin(self, charge: float, spin_multiplicity: int | None = None) -> None: + def set_charge_and_spin(self, charge: float, spin_multiplicity: int | None = None) -> Molecule: """Set the charge and spin multiplicity. Args: @@ -4577,6 +4635,9 @@ def set_charge_and_spin(self, charge: float, spin_multiplicity: int | None = Non Defaults to None, which means that the spin multiplicity is set to 1 if the molecule has no unpaired electrons and to 2 if there are unpaired electrons. + + Returns: + Molecule: self with new charge and spin multiplicity set. """ self._charge = charge n_electrons = 0.0 @@ -4596,6 +4657,8 @@ def set_charge_and_spin(self, charge: float, spin_multiplicity: int | None = Non else: self._spin_multiplicity = 1 if n_electrons % 2 == 0 else 2 + return self + def insert( # type: ignore self, idx: int, @@ -4628,11 +4691,14 @@ def insert( # type: ignore return self - def remove_species(self, species: Sequence[SpeciesLike]) -> None: + def remove_species(self, species: Sequence[SpeciesLike]) -> Molecule: """Remove all occurrences of a species from a molecule. Args: species: Species to remove. + + Returns: + Molecule: self with species removed. """ new_sites = [] species = [get_el_sp(sp) for sp in species] @@ -4641,16 +4707,21 @@ def remove_species(self, species: Sequence[SpeciesLike]) -> None: if len(new_sp_occu) > 0: new_sites.append(Site(new_sp_occu, site.coords, properties=site.properties, label=site.label)) self.sites = new_sites + return self - def remove_sites(self, indices: Sequence[int]) -> None: + def remove_sites(self, indices: Sequence[int]) -> Molecule: """Delete sites with at indices. Args: indices: Sequence of indices of sites to delete. + + Returns: + Molecule: self with sites removed. """ self.sites = [self[idx] for idx in range(len(self)) if idx not in indices] + return self - def translate_sites(self, indices: Sequence[int] | None = None, vector: ArrayLike | None = None) -> None: + def translate_sites(self, indices: Sequence[int] | None = None, vector: ArrayLike | None = None) -> Molecule: """Translate specific sites by some vector, keeping the sites within the unit cell. @@ -4658,6 +4729,9 @@ def translate_sites(self, indices: Sequence[int] | None = None, vector: ArrayLik indices (list): List of site indices on which to perform the translation. vector (3x1 array): Translation vector for sites. + + Returns: + Molecule: self with translated sites. """ if indices is None: indices = range(len(self)) @@ -4667,6 +4741,7 @@ def translate_sites(self, indices: Sequence[int] | None = None, vector: ArrayLik site = self[idx] new_site = Site(site.species, site.coords + vector, properties=site.properties, label=site.label) self[idx] = new_site + return self def rotate_sites( self, @@ -4674,7 +4749,7 @@ def rotate_sites( theta: float = 0.0, axis: ArrayLike | None = None, anchor: ArrayLike | None = None, - ) -> None: + ) -> Molecule: """Rotate specific sites by some angle around vector at anchor. Args: @@ -4683,6 +4758,9 @@ def rotate_sites( theta (float): Angle in radians axis (3x1 array): Rotation axis vector. anchor (3x1 array): Point of rotation. + + Returns: + Molecule: self with rotated sites. """ if indices is None: indices = range(len(self)) @@ -4706,13 +4784,17 @@ def rotate_sites( new_site = Site(site.species, s, properties=site.properties, label=site.label) self[idx] = new_site - def perturb(self, distance: float) -> None: + return self + + def perturb(self, distance: float) -> Molecule: """Performs a random perturbation of the sites in a structure to break symmetries. Args: - distance (float): Distance in angstroms by which to perturb each - site. + distance (float): Distance in angstroms by which to perturb each site. + + Returns: + Molecule: self with perturbed sites. """ def get_rand_vec(): @@ -4724,11 +4806,16 @@ def get_rand_vec(): for idx in range(len(self)): self.translate_sites([idx], get_rand_vec()) - def apply_operation(self, symmop: SymmOp) -> None: + return self + + def apply_operation(self, symmop: SymmOp) -> Molecule: """Apply a symmetry operation to the molecule. Args: symmop (SymmOp): Symmetry operation to apply. + + Returns: + Molecule: self after symmetry operation. """ def operate_site(site): @@ -4737,7 +4824,9 @@ def operate_site(site): self.sites = [operate_site(site) for site in self] - def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_order: int = 1) -> None: + return self + + def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_order: int = 1) -> Molecule: """Substitute atom at index with a functional group. Args: @@ -4758,6 +4847,9 @@ def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_or bond_order (int): A specified bond order to calculate the bond length between the attached functional group and the nearest neighbor site. Defaults to 1. + + Returns: + Molecule: self after substitution. """ # Find the nearest neighbor that is not a terminal atom. all_non_terminal_nn = [] @@ -4823,6 +4915,7 @@ def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_or # group. del self[index] self._sites += list(functional_group[1:]) + return self def relax( self, diff --git a/pymatgen/core/surface.py b/pymatgen/core/surface.py index e51969a5f51..6f9ad1a6c67 100644 --- a/pymatgen/core/surface.py +++ b/pymatgen/core/surface.py @@ -415,8 +415,8 @@ def normal(self): @property def surface_area(self): """Calculates the surface area of the slab.""" - m = self.lattice.matrix - return np.linalg.norm(np.cross(m[0], m[1])) + matrix = self.lattice.matrix + return np.linalg.norm(np.cross(matrix[0], matrix[1])) @property def center_of_mass(self): @@ -424,7 +424,7 @@ def center_of_mass(self): weights = [s.species.weight for s in self] return np.average(self.frac_coords, weights=weights, axis=0) - def add_adsorbate_atom(self, indices, specie, distance) -> None: + def add_adsorbate_atom(self, indices, specie, distance) -> Slab: """Gets the structure of single atom adsorption. slab structure from the Slab class(in [0, 0, 1]). @@ -435,14 +435,19 @@ def add_adsorbate_atom(self, indices, specie, distance) -> None: specie (Species/Element/str): adsorbed atom species distance (float): between centers of the adsorbed atom and the given site in Angstroms. + + Returns: + Slab: self with adsorbed atom. """ - # Let's do the work in Cartesian coords - center = np.sum([self[i].coords for i in indices], axis=0) / len(indices) + # Let's work in Cartesian coords + center = np.sum([self[idx].coords for idx in indices], axis=0) / len(indices) coords = center + self.normal * distance / np.linalg.norm(self.normal) self.append(specie, coords, coords_are_cartesian=True) + return self + def __str__(self) -> str: def to_str(x) -> str: return f"{x:0.6f}" diff --git a/tests/core/test_structure.py b/tests/core/test_structure.py index 25fe9b35d81..e344e5d9c78 100644 --- a/tests/core/test_structure.py +++ b/tests/core/test_structure.py @@ -946,17 +946,17 @@ def test_not_hashable(self): _ = {self.struct: 1} def test_sort(self): - struct = self.struct - struct[0] = "F" - struct.sort() - assert struct[0].species_string == "Si" - assert struct[1].species_string == "F" - struct.sort(key=lambda site: site.species_string) - assert struct[0].species_string == "F" - assert struct[1].species_string == "Si" - struct.sort(key=lambda site: site.species_string, reverse=True) - assert struct[0].species_string == "Si" - assert struct[1].species_string == "F" + self.struct[0] = "F" + returned = self.struct.sort() + assert returned is self.struct + assert self.struct[0].species_string == "Si" + assert self.struct[1].species_string == "F" + self.struct.sort(key=lambda site: site.species_string) + assert self.struct[0].species_string == "F" + assert self.struct[1].species_string == "Si" + self.struct.sort(key=lambda site: site.species_string, reverse=True) + assert self.struct[0].species_string == "Si" + assert self.struct[1].species_string == "F" def test_replace_species(self): struct = self.struct @@ -1027,13 +1027,15 @@ def test_append_insert_remove_replace_substitute(self): def test_add_remove_site_property(self): struct = self.struct - struct.add_site_property("charge", [4.1, -5]) + returned = struct.add_site_property("charge", [4.1, -5]) + assert returned is struct assert struct[0].charge == 4.1 assert struct[1].charge == -5 struct.add_site_property("magmom", [3, 2]) assert struct[0].charge == 4.1 assert struct[0].magmom == 3 - struct.remove_site_property("magmom") + returned = struct.remove_site_property("magmom") + assert returned is struct with pytest.raises(AttributeError, match="attr='magmom' not found on PeriodicSite"): _ = struct[0].magmom @@ -1066,7 +1068,8 @@ def test_propertied_structure(self): def test_perturb(self): dist = 0.1 pre_perturbation_sites = self.struct.copy() - self.struct.perturb(distance=dist) + returned = self.struct.perturb(distance=dist) + assert returned is self.struct post_perturbation_sites = self.struct.sites for idx, site in enumerate(pre_perturbation_sites): @@ -1080,9 +1083,10 @@ def test_perturb(self): assert site.distance(post_perturbation_sites2[idx]) <= dist assert site.distance(post_perturbation_sites2[idx]) >= 0 - def test_add_oxidation_states_by_element(self): + def test_add_oxidation_state_by_element(self): oxidation_states = {"Si": -4} - self.struct.add_oxidation_state_by_element(oxidation_states) + returned = self.struct.add_oxidation_state_by_element(oxidation_states) + assert returned is self.struct for site in self.struct: for specie in site.species: assert specie.oxi_state == oxidation_states[specie.symbol], "Wrong oxidation state assigned!" @@ -1091,7 +1095,8 @@ def test_add_oxidation_states_by_element(self): self.struct.add_oxidation_state_by_element(oxidation_states) def test_add_oxidation_states_by_site(self): - self.struct.add_oxidation_state_by_site([2, -4]) + returned = self.struct.add_oxidation_state_by_site([2, -4]) + assert returned is self.struct assert self.struct[0].specie.oxi_state == 2 with pytest.raises( ValueError, match="Oxidation states of all sites must be specified, expected 2 values, got 1" @@ -1107,14 +1112,14 @@ def test_remove_oxidation_states(self): lattice = Lattice.cubic(10) struct_elem = Structure(lattice, [co_elem, o_elem], coords) struct_specie = Structure(lattice, [co_specie, o_specie], coords) - struct_out = struct_specie.remove_oxidation_states() - assert struct_out is struct_specie + returned = struct_specie.remove_oxidation_states() + assert returned is struct_specie assert struct_elem == struct_specie, "Oxidation state remover failed" - def test_add_oxidation_states_by_guess(self): + def test_add_oxidation_state_by_guess(self): struct = PymatgenTest.get_structure("Li2O") - struct_with_oxi = struct.add_oxidation_state_by_guess() - assert struct_with_oxi is struct + returned = struct.add_oxidation_state_by_guess() + assert returned is struct expected = [Species("Li", 1), Species("O", -2)] for site in struct: assert site.specie in expected @@ -1146,7 +1151,8 @@ def test_add_remove_spin_states(self): def test_apply_operation(self): op = SymmOp.from_axis_angle_and_translation([0, 0, 1], 90) struct = self.struct.copy() - struct.apply_operation(op) + returned = struct.apply_operation(op) + assert returned is struct assert_allclose( struct.lattice.matrix, [[0, 3.840198, 0], [-3.325710, 1.920099, 0], [2.217138, -0, 3.135509]], @@ -1161,7 +1167,8 @@ def test_apply_operation(self): def test_apply_strain(self): struct = self.struct initial_coord = struct[1].coords - struct.apply_strain(0.01) + returned = struct.apply_strain(0.01) + assert returned is struct assert approx(struct.lattice.abc) == (3.8785999130369997, 3.878600984287687, 3.8785999130549516) assert_allclose(struct[1].coords, initial_coord * 1.01) a1, b1, c1 = struct.lattice.abc @@ -1182,7 +1189,8 @@ def test_apply_strain(self): def test_scale_lattice(self): initial_coord = self.struct[1].coords - self.struct.scale_lattice(self.struct.volume * 1.01**3) + returned = self.struct.scale_lattice(self.struct.volume * 1.01**3) + assert returned is self.struct assert_allclose( self.struct.lattice.abc, (3.8785999130369997, 3.878600984287687, 3.8785999130549516), @@ -1190,7 +1198,8 @@ def test_scale_lattice(self): assert_allclose(self.struct[1].coords, initial_coord * 1.01) def test_translate_sites(self): - self.struct.translate_sites([0, 1], [0.5, 0.5, 0.5], frac_coords=True) + returned = self.struct.translate_sites([0, 1], [0.5, 0.5, 0.5], frac_coords=True) + assert returned is self.struct assert_allclose(self.struct.frac_coords[0], [0.5, 0.5, 0.5]) self.struct.translate_sites([0], [0.5, 0.5, 0.5], frac_coords=False) @@ -1214,12 +1223,13 @@ def test_translate_sites(self): assert self.struct == original_struct def test_rotate_sites(self): - self.struct.rotate_sites( + returned = self.struct.rotate_sites( indices=[1], theta=2.0 * np.pi / 3.0, anchor=self.struct[0].coords, to_unit_cell=False, ) + assert returned is self.struct assert_allclose(self.struct.frac_coords[1], [-1.25, 1.5, 0.75], atol=1e-6) self.struct.rotate_sites( indices=[1], @@ -2098,17 +2108,20 @@ def test_mutable_sequence_methods(self): def test_insert_remove_append(self): mol = self.mol - mol.insert(1, "O", [0.5, 0.5, 0.5]) + returned = mol.insert(1, "O", [0.5, 0.5, 0.5]) + assert returned is mol assert mol.formula == "H4 C1 O1" del mol[2] assert mol.formula == "H3 C1 O1" mol.set_charge_and_spin(0) assert mol.spin_multiplicity == 2 - mol.append("N", [1, 1, 1]) + returned = mol.append("N", [1, 1, 1]) + assert returned is mol assert mol.formula == "H3 C1 N1 O1" with pytest.raises(TypeError, match="unhashable type: 'Molecule'"): _ = {mol: 1} - mol.remove_sites([0, 1]) + returned = mol.remove_sites([0, 1]) + assert returned is mol assert mol.formula == "H3 N1" def test_from_sites(self): @@ -2119,18 +2132,21 @@ def test_from_sites(self): Molecule.from_sites([]) def test_translate_sites(self): - self.mol.translate_sites([0, 1], translation := (0.5, 0.5, 0.5)) + returned = self.mol.translate_sites([0, 1], translation := (0.5, 0.5, 0.5)) + assert returned is self.mol assert tuple(self.mol.cart_coords[0]) == translation def test_rotate_sites(self): - self.mol.rotate_sites(theta=np.radians(30)) + returned = self.mol.rotate_sites(theta=np.radians(30)) + assert returned is self.mol assert_allclose(self.mol.cart_coords[2], [0.889164737, 0.513359500, -0.363000000]) def test_replace(self): self.mol[0] = "Ge" assert self.mol.formula == "Ge1 H4" - self.mol.replace_species({Element("Ge"): {Element("Ge"): 0.5, Element("Si"): 0.5}}) + returned = self.mol.replace_species({Element("Ge"): {Element("Ge"): 0.5, Element("Si"): 0.5}}) + assert returned is self.mol assert self.mol.formula == "Si0.5 Ge0.5 H4" # this should change the .5Si .5Ge sites to .75Si .25Ge @@ -2145,15 +2161,17 @@ def test_replace(self): for idx, site in enumerate(pre_perturbation_sites): assert site.distance(post_perturbation_sites[idx]) == approx(dist), "Bad perturbation distance" - def test_add_site_property(self): - self.mol.add_site_property("charge", [4.1, -2, -2, -2, -2]) + def test_add_remove_site_property(self): + returned = self.mol.add_site_property("charge", [4.1, -2, -2, -2, -2]) + assert returned is self.mol assert self.mol[0].charge == 4.1 assert self.mol[1].charge == -2 self.mol.add_site_property("magmom", [3, 2, 2, 2, 2]) assert self.mol[0].charge == 4.1 assert self.mol[0].magmom == 3 - self.mol.remove_site_property("magmom") + returned = self.mol.remove_site_property("magmom") + assert returned is self.mol # test ValueError when values have wrong length with pytest.raises(ValueError, match=r"len\(values\)=2 must equal sites in structure=5"): @@ -2171,7 +2189,8 @@ def test_as_from_dict(self): def test_apply_operation(self): op = SymmOp.from_axis_angle_and_translation([0, 0, 1], 90) - self.mol.apply_operation(op) + returned = self.mol.apply_operation(op) + assert returned is self.mol assert_allclose(self.mol[2].coords, [0, 1.026719, -0.363000], atol=1e-12) def test_substitute(self): @@ -2183,7 +2202,8 @@ def test_substitute(self): [-0.513360, 0.889165, -0.363000], ] sub = Molecule(["X", "C", "H", "H", "H"], coords) - self.mol.substitute(1, sub) + returned = self.mol.substitute(1, sub) + assert returned is self.mol assert self.mol.get_distance(0, 4) == approx(1.54) f = Molecule(["X", "F"], [[0, 0, 0], [0, 0, 1.11]]) self.mol.substitute(2, f) @@ -2242,14 +2262,17 @@ def test_no_spin_check(self): coords = [[0, 0, 0], [0, 0, 1.089000], [1.026719, 0, -0.363000], [-0.513360, -0.889165, -0.363000]] expected_msg = "Charge of 0 and spin multiplicity of 1 is not possible for this molecule" with pytest.raises(ValueError, match=expected_msg): - mol = Molecule(["C", "H", "H", "H"], coords, charge=0, spin_multiplicity=1) + Molecule(["C", "H", "H", "H"], coords, charge=0, spin_multiplicity=1) mol_valid = Molecule(["C", "H", "H", "H"], coords, charge=0, spin_multiplicity=2) with pytest.raises(ValueError, match=expected_msg): mol_valid.set_charge_and_spin(0, 1) - mol = Molecule(["C", "H", "H", "H"], coords, charge=0, spin_multiplicity=1, charge_spin_check=False) + + def test_set_charge_and_spin(self): + mol = Molecule.from_dict(self.mol.as_dict() | dict(charge=0, spin_multiplicity=1, charge_spin_check=False)) assert mol.spin_multiplicity == 1 assert mol.charge == 0 - mol.set_charge_and_spin(0, 3) + returned = mol.set_charge_and_spin(0, 3) + assert returned is mol assert mol.charge == 0 assert mol.spin_multiplicity == 3 diff --git a/tests/core/test_surface.py b/tests/core/test_surface.py index a10bdd75cd8..0e3f194018b 100644 --- a/tests/core/test_surface.py +++ b/tests/core/test_surface.py @@ -109,7 +109,8 @@ def test_add_adsorbate_atom(self): 0, self.zno55.scale_factor, ) - zno_slab.add_adsorbate_atom([1], "H", 1) + returned = zno_slab.add_adsorbate_atom([1], "H", 1) + assert returned == zno_slab assert len(zno_slab) == 9 assert str(zno_slab[8].specie) == "H" diff --git a/tests/files/.pytest-split-durations b/tests/files/.pytest-split-durations index 07e22cd3a32..7f5492de099 100644 --- a/tests/files/.pytest-split-durations +++ b/tests/files/.pytest-split-durations @@ -1036,7 +1036,7 @@ "tests/core/test_structure.py::TestMolecule::test_translate_sites": 0.0019293749937787652, "tests/core/test_structure.py::TestNeighbor::test_msonable": 0.0021466249600052834, "tests/core/test_structure.py::TestNeighbor::test_neighbor_labels": 0.0014545419835485518, - "tests/core/test_structure.py::TestStructure::test_add_oxidation_states_by_element": 0.06694620900088921, + "tests/core/test_structure.py::TestStructure::test_add_oxidation_state_by_element": 0.06694620900088921, "tests/core/test_structure.py::TestStructure::test_add_oxidation_states_by_guess": 0.002574582991655916, "tests/core/test_structure.py::TestStructure::test_add_oxidation_states_by_site": 0.0022316250251606107, "tests/core/test_structure.py::TestStructure::test_add_remove_site_property": 0.002218457986600697,