diff --git a/ramannoodle/exceptions.py b/ramannoodle/exceptions.py index e955f4c..a900de9 100644 --- a/ramannoodle/exceptions.py +++ b/ramannoodle/exceptions.py @@ -63,7 +63,7 @@ def get_type_error(name: str, value: Any, correct_type: str) -> TypeError: return TypeError(f"{name} should have type {correct_type}, not {wrong_type}") -def get_shape_error(name: str, array: NDArray, desired_shape: str) -> ValueError: +def get_shape_error(name: str, array: NDArray[Any], desired_shape: str) -> ValueError: """Return ValueError for an ndarray with the wrong shape. :meta private: @@ -72,7 +72,7 @@ def get_shape_error(name: str, array: NDArray, desired_shape: str) -> ValueError return ValueError(f"{name} has wrong shape: {shape_spec}") -def verify_ndarray(name: str, array: NDArray) -> None: +def verify_ndarray(name: str, array: NDArray[Any]) -> None: """Verify type of NDArray . :meta private: We should avoid calling this function wherever possible (EATF) @@ -84,7 +84,7 @@ def verify_ndarray(name: str, array: NDArray) -> None: def verify_ndarray_shape( - name: str, array: NDArray, shape: Sequence[int | None] + name: str, array: NDArray[Any], shape: Sequence[int | None] ) -> None: """Verify an NDArray's shape. @@ -123,7 +123,7 @@ def verify_list_len(name: str, array: list[Any], length: int | None) -> None: raise ValueError(f"{name} has wrong length: {len(array)} != length") -def verify_positions(name: str, array: NDArray) -> None: +def verify_positions(name: str, array: NDArray[Any]) -> None: """Verify fractional positions according to dimensions and boundary conditions. 
:meta private: @@ -136,5 +136,10 @@ def verify_positions(name: str, array: NDArray) -> None: def get_torch_missing_error() -> UserError: """Get error indicating that torch is not installed.""" - required_modules = "'torch', 'torch-scatter', and 'torch-sparse' modules" + required_modules = "'torch', 'torch-scatter', and 'torch-sparse' packages" return UserError(f"torch functionality requires {required_modules}") + + +def get_pymatgen_missing_error() -> UserError: + """Get error indicating that pymatgen is not installed.""" + return UserError("pymatgen functionality requires pymatgen package") diff --git a/ramannoodle/io/pymatgen/__init__.py b/ramannoodle/io/pymatgen/__init__.py new file mode 100644 index 0000000..5d078bd --- /dev/null +++ b/ramannoodle/io/pymatgen/__init__.py @@ -0,0 +1,18 @@ +"""Functions for interacting with pymatgen.""" + +# flake8: noqa: F401 +from ramannoodle.io.pymatgen.pymatgen import ( + get_positions, + get_structure, + construct_polarizability_dataset, + construct_ref_structure, + construct_trajectory, +) + +__all__ = [ + "get_positions", + "get_structure", + "construct_polarizability_dataset", + "construct_ref_structure", + "construct_trajectory", +] diff --git a/ramannoodle/io/pymatgen/pymatgen.py b/ramannoodle/io/pymatgen/pymatgen.py new file mode 100644 index 0000000..6f1fa79 --- /dev/null +++ b/ramannoodle/io/pymatgen/pymatgen.py @@ -0,0 +1,188 @@ +"""Functions for interacting with pymatgen.""" + +import numpy as np +from numpy.typing import NDArray + +from ramannoodle.exceptions import ( + get_pymatgen_missing_error, + UserError, + verify_list_len, + verify_ndarray_shape, + get_type_error, + IncompatibleStructureException, +) +from ramannoodle.dynamics.trajectory import Trajectory +from ramannoodle.structure.reference import ReferenceStructure + +try: + from ramannoodle.dataset.torch.dataset import PolarizabilityDataset +except UserError: + pass + +try: + import pymatgen.core + import pymatgen.core.trajectory +except ImportError as 
exc: + raise get_pymatgen_missing_error() from exc + + +def _get_lattice(pymatgen_structure: pymatgen.core.Structure) -> NDArray[np.float64]: + """Get lattice from a pymatgen Structure. + + Parameters + ---------- + pymatgen_structure + + Returns + ------- + : + (Å) Array with shape (3,3). + + """ + try: + return pymatgen_structure.lattice.matrix + except AttributeError as exc: + raise get_type_error( + "pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure" + ) from exc + + +def get_positions(pymatgen_structure: pymatgen.core.Structure) -> NDArray[np.float64]: + """Read fractional positions from a pymatgen Structure. + + Parameters + ---------- + pymatgen_structure + + Returns + ------- + : + (fractional) Array with shape (N,3) where N is the number of atoms. + + """ + try: + return pymatgen_structure.frac_coords + except AttributeError as exc: + raise get_type_error( + "pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure" + ) from exc + + +def _get_atomic_numbers(pymatgen_structure: pymatgen.core.Structure) -> list[int]: + """Get atomic numbers from a pymatgen Structure. + + Parameters + ---------- + pymatgen_structure + + Returns + ------- + : + List of length N where N is the number of atoms. + + """ + try: + return list(pymatgen_structure.atomic_numbers) + except AttributeError as exc: + raise get_type_error( + "pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure" + ) from exc + + +def get_structure( + pymatgen_structure: pymatgen.core.Structure, +) -> tuple[NDArray[np.float64], list[int], NDArray[np.float64]]: + """Get lattice, atomic numbers, and positions from a pymatgen Structure. + + Parameters + ---------- + pymatgen_structure + + Returns + ------- + : + 0. lattice -- (Å) Array with shape (3,3). + 1. atomic numbers -- List of length N where N is the number of atoms. + 2. positions -- (fractional) Array with shape (N,3) where N is the number of + atoms. 
+ """ + return ( + _get_lattice(pymatgen_structure), + _get_atomic_numbers(pymatgen_structure), + get_positions(pymatgen_structure), + ) + + +def construct_polarizability_dataset( + pymatgen_structures: list[pymatgen.core.Structure], + polarizabilities: NDArray[np.float64], +) -> "PolarizabilityDataset": + """Create a PolarizabilityDataset from pymatgen Structures and polarizabilities. + + Parameters + ---------- + pymatgen_structures + List of length M. + polarizabilities + Array with shape (M,3,3). + + """ + verify_list_len("pymatgen_structures", pymatgen_structures, None) + verify_ndarray_shape( + "polarizabilities", polarizabilities, (len(pymatgen_structures), 3, 3) + ) + lattice, atomic_numbers, _ = get_structure(pymatgen_structures[0]) + positions = np.zeros((len(pymatgen_structures), len(atomic_numbers), 3)) + for i, pymatgen_structure in enumerate(pymatgen_structures): + if not np.allclose(_get_lattice(pymatgen_structure), lattice, atol=1e-5): + raise IncompatibleStructureException( + f"incompatible lattice: pymatgen_structures[{i}]" + ) + if _get_atomic_numbers(pymatgen_structure) != atomic_numbers: + raise IncompatibleStructureException( + f"incompatible atomic numbers: pymatgen_structures[{i}]" + ) + positions[i] = get_positions(pymatgen_structure) + return PolarizabilityDataset(lattice, atomic_numbers, positions, polarizabilities) + + +def construct_ref_structure( + pymatgen_structure: pymatgen.core.Structure, +) -> ReferenceStructure: + """Create a ReferenceStructure from a pymatgen Structure. + + Parameters + ---------- + pymatgen_structure + + """ + return ReferenceStructure( + _get_atomic_numbers(pymatgen_structure), + _get_lattice(pymatgen_structure), + get_positions(pymatgen_structure), + ) + + +def construct_trajectory( + pymatgen_trajectory: pymatgen.core.trajectory.Trajectory, + timestep: float, +) -> Trajectory: + """Create a Trajectory from a pymatgen Trajectory. 
+ + Parameters + ---------- + pymatgen_trajectory + timestep + (fs) + + """ + try: + if not pymatgen_trajectory.constant_lattice: + raise ValueError("pymatgen_trajectory must have a constant lattice") + except AttributeError as exc: + raise get_type_error( + "pymatgen_trajectory", + pymatgen_trajectory, + "pymatgen.core.trajectory.Trajectory", + ) from exc + return Trajectory(pymatgen_trajectory.coords, timestep) diff --git a/test/tests/pymatgen/test_pymatgen.py b/test/tests/pymatgen/test_pymatgen.py new file mode 100644 index 0000000..5cb488e --- /dev/null +++ b/test/tests/pymatgen/test_pymatgen.py @@ -0,0 +1,91 @@ +"""Tests for pymatgen IO functions.""" + +from pathlib import Path + +import pytest + + +import numpy as np + +import pymatgen.core +import pymatgen.core.trajectory +import pymatgen.io.vasp + +import ramannoodle as rn +import ramannoodle.io.pymatgen as pymatgen_io + +# pylint: disable=protected-access + + +@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"]) +def test_get_positions(path_fixture: Path) -> None: + """Test get_positions (normal).""" + known_positions = rn.io.vasp.poscar.read_positions(path_fixture) + + positions = pymatgen_io.get_positions( + pymatgen.core.Structure.from_file(path_fixture) + ) + assert np.allclose(known_positions, positions) + + +@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"]) +def test_construct_ref_structure(path_fixture: Path) -> None: + """Test construct_ref_structure (normal).""" + known_ref_structure = rn.io.vasp.poscar.read_ref_structure(path_fixture) + + structure = pymatgen.core.Structure.from_file(path_fixture) + ref_structure = pymatgen_io.construct_ref_structure(structure) + + assert ref_structure.atomic_numbers == known_ref_structure.atomic_numbers + + +@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"]) +def test_get_structure(path_fixture: Path) -> None: + """Test get_structure (normal).""" + known_structure = rn.io.vasp.poscar.read_structure(path_fixture) 
+ + pymatgen_structure = pymatgen.core.Structure.from_file(path_fixture) + structure = pymatgen_io.get_structure(pymatgen_structure) + + assert np.allclose(known_structure[0], structure[0]) + assert known_structure[1] == structure[1] + assert np.allclose(known_structure[2], structure[2]) + + +@pytest.mark.parametrize("path_fixture", ["test/data/STO/XDATCAR"]) +def test_construct_trajectory(path_fixture: Path) -> None: + """Test construct_trajectory (normal).""" + known_trajectory = rn.io.vasp.xdatcar.read_trajectory(path_fixture, 1.0) + + pymatgen_trajectory = pymatgen.core.trajectory.Trajectory.from_file(path_fixture) + trajectory = pymatgen_io.construct_trajectory(pymatgen_trajectory, 1.0) + + assert np.allclose(known_trajectory.positions_ts, trajectory.positions_ts) + + +@pytest.mark.parametrize( + "filepaths", + [ + [ + "test/data/STO/vasprun.xml", + ], + ], +) +def test_load_polarizability_dataset(filepaths: str | list[str]) -> None: + """Test construct_polarizability_dataset (normal).""" + known_dataset = rn.io.generic.read_polarizability_dataset(filepaths, "vasprun.xml") + + structures = [] + polarizabilities = [] + for filepath in filepaths: + vasprun = pymatgen.io.vasp.outputs.Vasprun(filepath) + polarizabilities.append(np.array(np.array(vasprun.epsilon_static))) + pymatgen_structure = pymatgen.core.Structure.from_file(filepath) + + structures.append(pymatgen_structure) + + dataset = pymatgen_io.construct_polarizability_dataset( + structures, np.array(polarizabilities) + ) + + assert np.allclose(known_dataset.polarizabilities, dataset.polarizabilities)