Skip to content

Commit

Permalink
implemented io.pymatgen w/tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Oct 2, 2024
1 parent 1f9ebbe commit 4984f65
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 5 deletions.
15 changes: 10 additions & 5 deletions ramannoodle/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_type_error(name: str, value: Any, correct_type: str) -> TypeError:
return TypeError(f"{name} should have type {correct_type}, not {wrong_type}")


def get_shape_error(name: str, array: NDArray, desired_shape: str) -> ValueError:
def get_shape_error(name: str, array: NDArray[Any], desired_shape: str) -> ValueError:
"""Return ValueError for an ndarray with the wrong shape.
:meta private:
Expand All @@ -72,7 +72,7 @@ def get_shape_error(name: str, array: NDArray, desired_shape: str) -> ValueError
return ValueError(f"{name} has wrong shape: {shape_spec}")


def verify_ndarray(name: str, array: NDArray) -> None:
def verify_ndarray(name: str, array: NDArray[Any]) -> None:
"""Verify type of NDArray .
:meta private: We should avoid calling this function wherever possible (EATF)
Expand All @@ -84,7 +84,7 @@ def verify_ndarray(name: str, array: NDArray) -> None:


def verify_ndarray_shape(
name: str, array: NDArray, shape: Sequence[int | None]
name: str, array: NDArray[Any], shape: Sequence[int | None]
) -> None:
"""Verify an NDArray's shape.
Expand Down Expand Up @@ -123,7 +123,7 @@ def verify_list_len(name: str, array: list[Any], length: int | None) -> None:
raise ValueError(f"{name} has wrong length: {len(array)} != length")


def verify_positions(name: str, array: NDArray) -> None:
def verify_positions(name: str, array: NDArray[Any]) -> None:
"""Verify fractional positions according to dimensions and boundary conditions.
:meta private:
Expand All @@ -136,5 +136,10 @@ def verify_positions(name: str, array: NDArray) -> None:

def get_torch_missing_error() -> UserError:
"""Get error indicating that torch is not installed."""
required_modules = "'torch', 'torch-scatter', and 'torch-sparse' modules"
required_modules = "'torch', 'torch-scatter', and 'torch-sparse' packages"
return UserError(f"torch functionality requires {required_modules}")


def get_pymatgen_missing_error() -> UserError:
"""Get error indicating that pymatgen is not installed."""
return UserError("pymatgen functionality requires pymatgen package")
18 changes: 18 additions & 0 deletions ramannoodle/io/pymatgen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Functions for interacting with pymatgen."""

# flake8: noqa: F401
from ramannoodle.io.pymatgen.pymatgen import (
get_positions,
get_structure,
construct_polarizability_dataset,
construct_ref_structure,
construct_trajectory,
)

__all__ = [
"get_positions",
"get_structure",
"construct_polarizability_dataset",
"construct_ref_structure",
"construct_trajectory",
]
188 changes: 188 additions & 0 deletions ramannoodle/io/pymatgen/pymatgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Functions for interacting with pymatgen."""

import numpy as np
from numpy.typing import NDArray

from ramannoodle.exceptions import (
get_pymatgen_missing_error,
UserError,
verify_list_len,
verify_ndarray_shape,
get_type_error,
IncompatibleStructureException,
)
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure

try:
from ramannoodle.dataset.torch.dataset import PolarizabilityDataset
except UserError:
pass

try:
import pymatgen.core
import pymatgen.core.trajectory
except ImportError as exc:
raise get_pymatgen_missing_error() from exc


def _get_lattice(pymatgen_structure: pymatgen.core.Structure) -> NDArray[np.float64]:
"""Get lattice from a pymatgen Structure.
Parameters
----------
pymatgen_structure
Returns
-------
:
(Å) Array with shape (3,3).
"""
try:
return pymatgen_structure.lattice.matrix
except AttributeError as exc:
raise get_type_error(
"pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure"
) from exc


def get_positions(pymatgen_structure: pymatgen.core.Structure) -> NDArray[np.float64]:
"""Read fractional positions from a pymatgen Structure.
Parameters
----------
pymatgen_structure
Returns
-------
:
(fractional) Array with shape (N,3) where N is the number of atoms.
"""
try:
return pymatgen_structure.frac_coords
except AttributeError as exc:
raise get_type_error(
"pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure"
) from exc


def _get_atomic_numbers(pymatgen_structure: pymatgen.core.Structure) -> list[int]:
"""Get atomic numbers from a pymatgen Structure.
Parameters
----------
pymatgen_structure
Returns
-------
:
List of length N where N is the number of atoms.
"""
try:
return list(pymatgen_structure.atomic_numbers)
except AttributeError as exc:
raise get_type_error(
"pymatgen_structure", pymatgen_structure, "pymatgen.core.Structure"
) from exc


def get_structure(
pymatgen_structure: pymatgen.core.Structure,
) -> tuple[NDArray[np.float64], list[int], NDArray[np.float64]]:
"""Get lattice, positions, and atomic numbers from a pymatgen Structure.
Parameters
----------
pymatgen_structure
Returns
-------
:
0. lattice -- (Å) Array with shape (3,3).
1. atomic numbers -- List of length N where N is the number of atoms.
2. positions -- (fractional) Array with shape (N,3) where N is the number of
atoms.
"""
return (
_get_lattice(pymatgen_structure),
_get_atomic_numbers(pymatgen_structure),
get_positions(pymatgen_structure),
)


def construct_polarizability_dataset(
pymatgen_structures: list[pymatgen.core.Structure],
polarizabilities: NDArray[np.float64],
) -> "PolarizabilityDataset":
"""Create a PolarizabilityDataset from of pymatgen Structures and polarizabilities.
Parameters
----------
pymatgen_structures
List of length M.
polarizabilities
Array with shape (M,3,3).
"""
verify_list_len("pymatgen_structures", pymatgen_structures, None)
verify_ndarray_shape(
"polarizabilities", polarizabilities, (len(pymatgen_structures), 3, 3)
)
lattice, atomic_numbers, _ = get_structure(pymatgen_structures[0])
positions = np.zeros((len(pymatgen_structures), len(atomic_numbers), 3))
for i, pymatgen_structure in enumerate(pymatgen_structures):
if not np.allclose(_get_lattice(pymatgen_structure), lattice, atol=1e-5):
raise IncompatibleStructureException(
f"incompatible lattice: pymatgen_structures[{i}]"
)
if _get_atomic_numbers(pymatgen_structure) != atomic_numbers:
raise IncompatibleStructureException(
f"incompatible atomic numbers: pymatgen_structures[{i}]"
)
positions[i] = get_positions(pymatgen_structure)
return PolarizabilityDataset(lattice, atomic_numbers, positions, polarizabilities)


def construct_ref_structure(
pymatgen_structure: pymatgen.core.Structure,
) -> ReferenceStructure:
"""Create a ReferenceStructure from a pymatgen Structure.
Parameters
----------
pymatgen_structure
"""
return ReferenceStructure(
_get_atomic_numbers(pymatgen_structure),
_get_lattice(pymatgen_structure),
get_positions(pymatgen_structure),
)


def construct_trajectory(
pymatgen_trajectory: pymatgen.core.trajectory.Trajectory,
timestep: float,
) -> Trajectory:
"""Create a Trajectory from a pymatgen Trajectory.
Parameters
----------
pymatgen_trajectory
timestep
(fs)
"""
try:
if not pymatgen_trajectory.constant_lattice:
raise ValueError("pymatgen_trajectory must have a constant lattice")
except AttributeError as exc:
raise get_type_error(
"pymatgen_trajectory",
pymatgen_trajectory,
"pymatgen.core.trajectory.Trajectory",
) from exc
return Trajectory(pymatgen_trajectory.coords, timestep)
91 changes: 91 additions & 0 deletions test/tests/pymatgen/test_pymatgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Tests for pymatgen IO functions."""

from pathlib import Path

import pytest


import numpy as np

import pymatgen.core
import pymatgen.core.trajectory
import pymatgen.io.vasp

import ramannoodle as rn
import ramannoodle.io.pymatgen as pymatgen_io

# pylint: disable=protected-access


@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"])
def test_get_positions(path_fixture: Path) -> None:
"""Test get_positions (normal)."""
known_positions = rn.io.vasp.poscar.read_positions(path_fixture)

positions = pymatgen_io.get_positions(
pymatgen.core.Structure.from_file(path_fixture)
)
assert np.allclose(known_positions, positions)


@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"])
def test_construct_ref_structure(path_fixture: Path) -> None:
"""Test construct_ref_structure (normal)."""
known_ref_structure = rn.io.vasp.poscar.read_ref_structure(path_fixture)

structure = pymatgen.core.Structure.from_file(path_fixture)
ref_structure = pymatgen_io.construct_ref_structure(structure)

assert ref_structure.atomic_numbers == known_ref_structure.atomic_numbers


@pytest.mark.parametrize("path_fixture", ["test/data/TiO2/POSCAR"])
def test_get_structure(path_fixture: Path) -> None:
"""Test get_structure (normal)."""
known_structure = rn.io.vasp.poscar.read_structure(path_fixture)

pymatgen_structure = pymatgen.core.Structure.from_file(path_fixture)
structure = pymatgen_io.get_structure(pymatgen_structure)

assert np.allclose(known_structure[0], structure[0])
assert known_structure[1] == structure[1]
assert np.allclose(known_structure[2], structure[2])


@pytest.mark.parametrize("path_fixture", ["test/data/STO/XDATCAR"])
def test_construct_trajectory(path_fixture: Path) -> None:
"""Test construct_trajectory (normal)."""
known_trajectory = rn.io.vasp.xdatcar.read_trajectory(path_fixture, 1.0)

pymatgen_trajectory = pymatgen.core.trajectory.Trajectory.from_file(path_fixture)
trajectory = pymatgen_io.construct_trajectory(pymatgen_trajectory, 1.0)

assert np.allclose(known_trajectory.positions_ts, trajectory.positions_ts)


@pytest.mark.parametrize(
"filepaths",
[
[
"test/data/STO/vasprun.xml",
],
],
)
def test_load_polarizability_dataset(filepaths: str | list[str]) -> None:
"""Test of construct_polarizability_dataset (normal)."""
known_dataset = rn.io.generic.read_polarizability_dataset(filepaths, "vasprun.xml")

structures = []
polarizabilities = []
for filepath in filepaths:
vasprun = pymatgen.io.vasp.outputs.Vasprun(filepath)
polarizabilities.append(np.array(np.array(vasprun.epsilon_static)))
pymatgen_structure = pymatgen.core.Structure.from_file(filepath)

structures.append(pymatgen_structure)

dataset = pymatgen_io.construct_polarizability_dataset(
structures, np.array(polarizabilities)
)

assert np.allclose(known_dataset.polarizabilities, dataset.polarizabilities)

0 comments on commit 4984f65

Please sign in to comment.