-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
307 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Functions for interacting with pymatgen.""" | ||
|
||
# flake8: noqa: F401 | ||
from ramannoodle.io.pymatgen.pymatgen import ( | ||
get_positions, | ||
get_structure, | ||
construct_polarizability_dataset, | ||
construct_ref_structure, | ||
construct_trajectory, | ||
) | ||
|
||
__all__ = [ | ||
"get_positions", | ||
"get_structure", | ||
"construct_polarizability_dataset", | ||
"construct_ref_structure", | ||
"construct_trajectory", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
"""Functions for interacting with pymatgen.""" | ||
|
||
import numpy as np | ||
from numpy.typing import NDArray | ||
|
||
from ramannoodle.exceptions import ( | ||
get_pymatgen_missing_error, | ||
UserError, | ||
verify_list_len, | ||
verify_ndarray_shape, | ||
get_type_error, | ||
IncompatibleStructureException, | ||
) | ||
from ramannoodle.dynamics.trajectory import Trajectory | ||
from ramannoodle.structure.reference import ReferenceStructure | ||
|
||
try: | ||
from ramannoodle.dataset.torch.dataset import PolarizabilityDataset | ||
except UserError: | ||
pass | ||
|
||
try: | ||
import pymatgen.core | ||
import pymatgen.core.trajectory | ||
except ImportError as exc: | ||
raise get_pymatgen_missing_error() from exc | ||
|
||
|
||
def _get_lattice(pymatgen_structure: pymatgen.core.Structure) -> NDArray[np.float64]: | ||
"""Get lattice from a pymatgen Structure. | ||
Parameters | ||
---------- | ||
pymatgen_structure | ||
Returns | ||
------- | ||
: | ||
(Å) Array with shape (3,3). | ||
""" | ||
try: | ||
return pymatgen_structure.lattice.matrix | ||
except AttributeError as exc: | ||
raise get_type_error( | ||
"pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure" | ||
) from exc | ||
|
||
|
||
def get_positions(pymatgen_structure: pymatgen.core.Structure) -> NDArray[np.float64]: | ||
"""Read fractional positions from a pymatgen Structure. | ||
Parameters | ||
---------- | ||
pymatgen_structure | ||
Returns | ||
------- | ||
: | ||
(fractional) Array with shape (N,3) where N is the number of atoms. | ||
""" | ||
try: | ||
return pymatgen_structure.frac_coords | ||
except AttributeError as exc: | ||
raise get_type_error( | ||
"pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure" | ||
) from exc | ||
|
||
|
||
def _get_atomic_numbers(pymatgen_structure: pymatgen.core.Structure) -> list[int]: | ||
"""Get atomic numbers from a pymatgen Structure. | ||
Parameters | ||
---------- | ||
pymatgen_structure | ||
Returns | ||
------- | ||
: | ||
List of length N where N is the number of atoms. | ||
""" | ||
try: | ||
return list(pymatgen_structure.atomic_numbers) | ||
except AttributeError as exc: | ||
raise get_type_error( | ||
"pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure" | ||
) from exc | ||
|
||
|
||
def get_structure( | ||
pymatgen_structure: pymatgen.core.Structure, | ||
) -> tuple[NDArray[np.float64], list[int], NDArray[np.float64]]: | ||
"""Get lattice, positions, and atomic numbers from a pymatgen Structure. | ||
Parameters | ||
---------- | ||
pymatgen_structure | ||
Returns | ||
------- | ||
: | ||
0. lattice -- (Å) Array with shape (3,3). | ||
1. atomic numbers -- List of length N where N is the number of atoms. | ||
2. positions -- (fractional) Array with shape (N,3) where N is the number of | ||
atoms. | ||
""" | ||
return ( | ||
_get_lattice(pymatgen_structure), | ||
_get_atomic_numbers(pymatgen_structure), | ||
get_positions(pymatgen_structure), | ||
) | ||
|
||
|
||
def construct_polarizability_dataset( | ||
pymatgen_structures: list[pymatgen.core.Structure], | ||
polarizabilities: NDArray[np.float64], | ||
) -> "PolarizabilityDataset": | ||
"""Create a PolarizabilityDataset from of pymatgen Structures and polarizabilities. | ||
Parameters | ||
---------- | ||
pymatgen_structures | ||
List of length M. | ||
polarizabilities | ||
Array with shape (M,3,3). | ||
""" | ||
verify_list_len("pymatgen_structures", pymatgen_structures, None) | ||
verify_ndarray_shape( | ||
"polarizabilities", polarizabilities, (len(pymatgen_structures), 3, 3) | ||
) | ||
lattice, atomic_numbers, _ = get_structure(pymatgen_structures[0]) | ||
positions = np.zeros((len(pymatgen_structures), len(atomic_numbers), 3)) | ||
for i, pymatgen_structure in enumerate(pymatgen_structures): | ||
if not np.allclose(_get_lattice(pymatgen_structure), lattice, atol=1e-5): | ||
raise IncompatibleStructureException( | ||
f"incompatible lattice: pymatgen_structures[{i}]" | ||
) | ||
if _get_atomic_numbers(pymatgen_structure) != atomic_numbers: | ||
raise IncompatibleStructureException( | ||
f"incompatible atomic numbers: pymatgen_structures[{i}]" | ||
) | ||
positions[i] = get_positions(pymatgen_structure) | ||
return PolarizabilityDataset(lattice, atomic_numbers, positions, polarizabilities) | ||
|
||
|
||
def construct_ref_structure( | ||
pymatgen_structure: pymatgen.core.Structure, | ||
) -> ReferenceStructure: | ||
"""Create a ReferenceStructure from a pymatgen Structure. | ||
Parameters | ||
---------- | ||
pymatgen_structure | ||
""" | ||
return ReferenceStructure( | ||
_get_atomic_numbers(pymatgen_structure), | ||
_get_lattice(pymatgen_structure), | ||
get_positions(pymatgen_structure), | ||
) | ||
|
||
|
||
def construct_trajectory( | ||
pymatgen_trajectory: pymatgen.core.trajectory.Trajectory, | ||
timestep: float, | ||
) -> Trajectory: | ||
"""Create a Trajectory from a pymatgen Trajectory. | ||
Parameters | ||
---------- | ||
pymatgen_trajectory | ||
timestep | ||
(fs) | ||
""" | ||
try: | ||
if not pymatgen_trajectory.constant_lattice: | ||
raise ValueError("pymatgen_trajectory must have a constant lattice") | ||
except AttributeError as exc: | ||
raise get_type_error( | ||
"pymatgen_trajectory", | ||
pymatgen_trajectory, | ||
"pymatgen.core.trajectory.Trajectory", | ||
) from exc | ||
return Trajectory(pymatgen_trajectory.coords, timestep) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
"""Tests for pymatgen IO functions.""" | ||
|
||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
|
||
import numpy as np | ||
|
||
import pymatgen.core | ||
import pymatgen.core.trajectory | ||
import pymatgen.io.vasp | ||
|
||
import ramannoodle as rn | ||
import ramannoodle.io.pymatgen as pymatgen_io | ||
|
||
# pylint: disable=protected-access | ||
|
||
|
||
@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"]) | ||
def test_get_positions(path_fixture: Path) -> None: | ||
"""Test get_positions (normal).""" | ||
known_positions = rn.io.vasp.poscar.read_positions(path_fixture) | ||
|
||
positions = pymatgen_io.get_positions( | ||
pymatgen.core.Structure.from_file(path_fixture) | ||
) | ||
assert np.allclose(known_positions, positions) | ||
|
||
|
||
@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"]) | ||
def test_construct_ref_structure(path_fixture: Path) -> None: | ||
"""Test construct_ref_structure (normal).""" | ||
known_ref_structure = rn.io.vasp.poscar.read_ref_structure(path_fixture) | ||
|
||
structure = pymatgen.core.Structure.from_file(path_fixture) | ||
ref_structure = pymatgen_io.construct_ref_structure(structure) | ||
|
||
assert ref_structure.atomic_numbers == known_ref_structure.atomic_numbers | ||
|
||
|
||
@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"]) | ||
def test_get_structure(path_fixture: Path) -> None: | ||
"""Test get_structure (normal).""" | ||
known_structure = rn.io.vasp.poscar.read_structure(path_fixture) | ||
|
||
pymatgen_structure = pymatgen.core.Structure.from_file(path_fixture) | ||
structure = pymatgen_io.get_structure(pymatgen_structure) | ||
|
||
assert np.allclose(known_structure[0], structure[0]) | ||
assert known_structure[1] == structure[1] | ||
assert np.allclose(known_structure[2], structure[2]) | ||
|
||
|
||
@pytest.mark.parametrize("path_fixture", ["test/data/STO/XDATCAR"]) | ||
def test_construct_trajectory(path_fixture: Path) -> None: | ||
"""Test construct_trajectory (normal).""" | ||
known_trajectory = rn.io.vasp.xdatcar.read_trajectory(path_fixture, 1.0) | ||
|
||
pymatgen_trajectory = pymatgen.core.trajectory.Trajectory.from_file(path_fixture) | ||
trajectory = pymatgen_io.construct_trajectory(pymatgen_trajectory, 1.0) | ||
|
||
assert np.allclose(known_trajectory.positions_ts, trajectory.positions_ts) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"filepaths", | ||
[ | ||
[ | ||
"test/data/STO/vasprun.xml", | ||
], | ||
], | ||
) | ||
def test_load_polarizability_dataset(filepaths: str | list[str]) -> None: | ||
"""Test of construct_polarizability_dataset (normal).""" | ||
known_dataset = rn.io.generic.read_polarizability_dataset(filepaths, "vasprun.xml") | ||
|
||
structures = [] | ||
polarizabilities = [] | ||
for filepath in filepaths: | ||
vasprun = pymatgen.io.vasp.outputs.Vasprun(filepath) | ||
polarizabilities.append(np.array(np.array(vasprun.epsilon_static))) | ||
pymatgen_structure = pymatgen.core.Structure.from_file(filepath) | ||
|
||
structures.append(pymatgen_structure) | ||
|
||
dataset = pymatgen_io.construct_polarizability_dataset( | ||
structures, np.array(polarizabilities) | ||
) | ||
|
||
assert np.allclose(known_dataset.polarizabilities, dataset.polarizabilities) |