Skip to content

Commit

Permalink
Merge pull request #318 from tovrstra/attr-attrs
Browse files Browse the repository at this point in the history
Migrate from attr to attrs
  • Loading branch information
tovrstra authored Jun 4, 2024
2 parents d82917a + 6047506 commit 9f7e987
Show file tree
Hide file tree
Showing 25 changed files with 199 additions and 183 deletions.
31 changes: 16 additions & 15 deletions iodata/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from numbers import Integral
from typing import Union

import attr
import attrs
import numpy as np
from numpy.typing import NDArray

from .attrutils import validate_shape

Expand Down Expand Up @@ -100,7 +101,7 @@ def angmom_its(angmom: Union[int, list[int]]) -> Union[str, list[str]]:
return ANGMOM_CHARS[angmom]


@attr.s(auto_attribs=True, slots=True, on_setattr=[attr.setters.validate, attr.setters.convert])
@attrs.define
class Shell:
"""A shell in a molecular basis representing (generalized) contractions with the same exponents.
Expand All @@ -126,11 +127,11 @@ class Shell:
"""

icenter: int
angmoms: list[int] = attr.ib(validator=validate_shape(("coeffs", 1)))
kinds: list[str] = attr.ib(validator=validate_shape(("coeffs", 1)))
exponents: np.ndarray = attr.ib(validator=validate_shape(("coeffs", 0)))
coeffs: np.ndarray = attr.ib(validator=validate_shape(("exponents", 0), ("kinds", 0)))
icenter: int = attrs.field()
angmoms: list[int] = attrs.field(validator=validate_shape(("coeffs", 1)))
kinds: list[str] = attrs.field(validator=validate_shape(("coeffs", 1)))
exponents: NDArray = attrs.field(validator=validate_shape(("coeffs", 0)))
coeffs: NDArray = attrs.field(validator=validate_shape(("exponents", 0), ("kinds", 0)))

@property
def nbasis(self) -> int:
Expand All @@ -156,7 +157,7 @@ def ncon(self) -> int:
return len(self.angmoms)


@attr.s(auto_attribs=True, slots=True, on_setattr=[attr.setters.validate, attr.setters.convert])
@attrs.define
class MolecularBasis:
"""A complete molecular orbital or density basis set.
Expand Down Expand Up @@ -205,9 +206,9 @@ class MolecularBasis:
"""

shells: list[Shell]
conventions: dict[str, str]
primitive_normalization: str
shells: list[Shell] = attrs.field()
conventions: dict[str, str] = attrs.field()
primitive_normalization: str = attrs.field()

@property
def nbasis(self) -> int:
Expand All @@ -222,12 +223,12 @@ def get_segmented(self):
shells.append(
Shell(shell.icenter, [angmom], [kind], shell.exponents, coeffs.reshape(-1, 1))
)
return attr.evolve(self, shells=shells)
return attrs.evolve(self, shells=shells)


def convert_convention_shell(
conv1: list[str], conv2: list[str], reverse=False
) -> tuple[np.ndarray, np.ndarray]:
) -> tuple[NDArray, NDArray]:
"""Return a permutation vector and sign changes to convert from 1 to 2.
The transformation from convention 1 to convention 2 can be done applying
Expand Down Expand Up @@ -289,7 +290,7 @@ def convert_convention_shell(

def convert_conventions(
molbasis: MolecularBasis, new_conventions: dict[str, list[str]], reverse=False
) -> tuple[np.ndarray, np.ndarray]:
) -> tuple[NDArray, NDArray]:
"""Return a permutation vector and sign changes to convert from 1 to 2.
The transformation from molbasis.convention to the new convention can be done
Expand Down Expand Up @@ -339,7 +340,7 @@ def convert_conventions(
return np.array(permutation), np.array(signs)


def iter_cart_alphabet(n: int) -> np.ndarray:
def iter_cart_alphabet(n: int) -> NDArray:
"""Loop over powers of Cartesian basis functions in alphabetical order.
See https://theochem.github.io/horton/2.1.1/tech_ref_gaussian_basis.html
Expand Down
3 changes: 2 additions & 1 deletion iodata/formats/chgcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..periodic import sym2num
Expand All @@ -37,7 +38,7 @@
PATTERNS = ["CHGCAR*", "AECCAR*"]


def _load_vasp_header(lit: LineIterator) -> tuple[str, np.ndarray, np.ndarray, np.ndarray]:
def _load_vasp_header(lit: LineIterator) -> tuple[str, NDArray, NDArray, NDArray]:
"""Load the cell and atoms from a VASP file format.
Parameters
Expand Down
15 changes: 7 additions & 8 deletions iodata/formats/cp2klog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Union

import numpy as np
from numpy.typing import NDArray
from scipy.special import factorialk

from ..basis import HORTON2_CONVENTIONS, MolecularBasis, Shell, angmom_sti
Expand All @@ -42,9 +43,7 @@
}


def _get_cp2k_norm_corrections(
ell: int, alphas: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
def _get_cp2k_norm_corrections(ell: int, alphas: Union[float, NDArray]) -> Union[float, NDArray]:
"""Compute the corrections for the normalization of the basis functions.
This correction is needed because the CP2K atom code works with a different
Expand Down Expand Up @@ -236,7 +235,7 @@ def _read_cp2k_occupations_energies(

def _read_cp2k_orbital_coeffs(
lit: LineIterator, oe: list[tuple[int, int, float, float]]
) -> dict[tuple[int, int], np.ndarray]:
) -> dict[tuple[int, int], NDArray]:
"""Read the expansion coefficients of the orbital from an open CP2K ATOM output.
Parameters
Expand Down Expand Up @@ -294,11 +293,11 @@ def _get_norb_nel(oe: list[tuple[int, int, float, float]]) -> tuple[int, float]:


def _fill_orbitals(
orb_coeffs: np.ndarray,
orb_energies: np.ndarray,
orb_occupations: np.ndarray,
orb_coeffs: NDArray,
orb_energies: NDArray,
orb_occupations: NDArray,
oe: list[tuple[int, int, float, float]],
coeffs: dict[tuple[int, int], np.ndarray],
coeffs: dict[tuple[int, int], NDArray],
obasis: MolecularBasis,
restricted: bool,
):
Expand Down
19 changes: 10 additions & 9 deletions iodata/formats/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import TextIO

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_dump_one, document_load_one
from ..iodata import IOData
Expand All @@ -42,7 +43,7 @@

def _read_cube_header(
lit: LineIterator,
) -> tuple[str, np.ndarray, np.ndarray, np.ndarray, dict[str, np.ndarray], np.ndarray]:
) -> tuple[str, NDArray, NDArray, NDArray, dict[str, NDArray], NDArray]:
"""Load header data from a CUBE file object.
Parameters
Expand All @@ -62,7 +63,7 @@ def _read_cube_header(
# skip the second line
next(lit)

def read_grid_line(line: str) -> tuple[int, np.ndarray]:
def read_grid_line(line: str) -> tuple[int, NDArray]:
"""Read a grid line from the cube file."""
words = line.split()
return (
Expand All @@ -83,7 +84,7 @@ def read_grid_line(line: str) -> tuple[int, np.ndarray]:
cellvecs = axes * shape.reshape(-1, 1)
cube = {"origin": origin, "axes": axes, "shape": shape}

def read_atom_line(line: str) -> tuple[int, float, np.ndarray]:
def read_atom_line(line: str) -> tuple[int, float, NDArray]:
"""Read an atomic number and coordinate from the cube file."""
words = line.split()
return (
Expand All @@ -106,7 +107,7 @@ def read_atom_line(line: str) -> tuple[int, float, np.ndarray]:
return title, atcoords, atnums, cellvecs, cube, atcorenums


def _read_cube_data(lit: LineIterator, cube: dict[str, np.ndarray]):
def _read_cube_data(lit: LineIterator, cube: dict[str, NDArray]):
"""Load cube data from a CUBE file object.
Parameters
Expand Down Expand Up @@ -150,10 +151,10 @@ def load_one(lit: LineIterator) -> dict:
def _write_cube_header(
f: TextIO,
title: str,
atcoords: np.ndarray,
atnums: np.ndarray,
cube: dict[str, np.ndarray],
atcorenums: np.ndarray,
atcoords: NDArray,
atnums: NDArray,
cube: dict[str, NDArray],
atcorenums: NDArray,
):
print(title, file=f)
print("OUTER LOOP: X, MIDDLE LOOP: Y, INNER LOOP: Z", file=f)
Expand All @@ -169,7 +170,7 @@ def _write_cube_header(
print(f"{atnums[i]:5d} {q: 11.6f} {x: 11.6f} {y: 11.6f} {z: 11.6f}", file=f)


def _write_cube_data(f: TextIO, cube_data: np.ndarray, block_size: int):
def _write_cube_data(f: TextIO, cube_data: NDArray, block_size: int):
counter = 0
for value in cube_data.flat:
f.write(f" {value: 12.5E}")
Expand Down
7 changes: 4 additions & 3 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Optional, TextIO

import numpy as np
from numpy.typing import NDArray

from ..basis import HORTON2_CONVENTIONS, MolecularBasis, Shell, convert_conventions
from ..docstrings import document_dump_one, document_load_many, document_load_one
Expand Down Expand Up @@ -473,7 +474,7 @@ def _load_dm(label: str, fchk: dict, result: dict, key: str):
result[key] = _triangle_to_dense(fchk[label])


def _triangle_to_dense(triangle: np.ndarray) -> np.ndarray:
def _triangle_to_dense(triangle: NDArray) -> NDArray:
"""Convert a symmetric matrix in triangular storage to a dense square matrix.
Parameters
Expand Down Expand Up @@ -512,7 +513,7 @@ def _dump_real_scalars(name: str, val: float, f: TextIO):
print(f"{name:40} R {float(val): 16.8E}", file=f)


def _dump_integer_arrays(name: str, val: np.ndarray, f: TextIO):
def _dump_integer_arrays(name: str, val: NDArray, f: TextIO):
"""Dumper for a array of integers."""
nval = val.size
if nval != 0:
Expand All @@ -527,7 +528,7 @@ def _dump_integer_arrays(name: str, val: np.ndarray, f: TextIO):
k = 0


def _dump_real_arrays(name: str, val: np.ndarray, f: TextIO):
def _dump_real_arrays(name: str, val: NDArray, f: TextIO):
"""Dumper for a array of float."""
nval = val.size
if nval != 0:
Expand Down
13 changes: 7 additions & 6 deletions iodata/formats/gamess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""GAMESS punch file format."""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..utils import LineIterator, angstrom
Expand All @@ -29,7 +30,7 @@
PATTERNS = ["*.dat"]


def _read_data(lit: LineIterator) -> tuple:
def _read_data(lit: LineIterator) -> tuple[str, str, list[str]]:
"""Extract ``title``, ``symmetry`` and ``symbols`` from the punch file."""
title = next(lit).strip()
symmetry = next(lit).split()[0]
Expand All @@ -46,7 +47,7 @@ def _read_data(lit: LineIterator) -> tuple:
return title, symmetry, symbols


def _read_coordinates(lit: LineIterator, result: dict) -> tuple:
def _read_coordinates(lit: LineIterator, result: dict[str]) -> tuple[NDArray, NDArray]:
"""Extract ``numbers`` and ``coordinates`` from the punch file."""
for _ in range(2):
next(lit)
Expand All @@ -67,7 +68,7 @@ def _read_coordinates(lit: LineIterator, result: dict) -> tuple:
return numbers, coordinates


def _read_energy(lit: LineIterator, result: dict) -> tuple:
def _read_energy(lit: LineIterator, result: dict[str]) -> tuple[float, NDArray]:
"""Extract ``energy`` and ``gradient`` from the punch file."""
energy = float(next(lit).split()[1])
natom = len(result["symbols"])
Expand All @@ -81,7 +82,7 @@ def _read_energy(lit: LineIterator, result: dict) -> tuple:
return energy, gradient


def _read_hessian(lit: LineIterator, result: dict) -> np.ndarray:
def _read_hessian(lit: LineIterator, result: dict[str]) -> NDArray:
"""Extract ``hessian`` from the punch file."""
# check that $HESS is not already parsed
if "athessian" in result:
Expand All @@ -102,7 +103,7 @@ def _read_hessian(lit: LineIterator, result: dict) -> np.ndarray:
return hessian


def _read_masses(lit: LineIterator, result: dict) -> np.ndarray:
def _read_masses(lit: LineIterator, result: dict[str]) -> NDArray:
"""Extract ``masses`` from the punch file."""
natom = len(result["symbols"])
masses = np.zeros(natom, float)
Expand All @@ -119,7 +120,7 @@ def _read_masses(lit: LineIterator, result: dict) -> np.ndarray:
"PUNCH",
["title", "energy", "grot", "atgradient", "athessian", "atmasses", "atnums", "atcoords"],
)
def load_one(lit: LineIterator) -> dict:
def load_one(lit: LineIterator) -> dict[str]:
"""Do not edit this docstring. It will be overwritten."""
result = {}
while True:
Expand Down
5 changes: 3 additions & 2 deletions iodata/formats/gaussianlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..utils import LineIterator, set_four_index_element
Expand Down Expand Up @@ -73,7 +74,7 @@ def load_one(lit: LineIterator) -> dict:
return result


def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> NDArray:
"""Load a two-index operator from a GAUSSIAN LOG file format.
Parameters
Expand Down Expand Up @@ -106,7 +107,7 @@ def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
return result


def _load_fourindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
def _load_fourindex_g09(lit: LineIterator, nbasis: int) -> NDArray:
"""Load a four-index operator from a GAUSSIAN LOG file.
Parameters
Expand Down
7 changes: 3 additions & 4 deletions iodata/formats/mol2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TextIO

import numpy as np
from numpy.typing import NDArray

from ..docstrings import (
document_dump_many,
Expand Down Expand Up @@ -83,9 +84,7 @@ def load_one(lit: LineIterator) -> dict:
return result


def _load_helper_atoms(
lit: LineIterator, natoms: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple]:
def _load_helper_atoms(lit: LineIterator, natoms: int) -> tuple[NDArray, NDArray, NDArray, tuple]:
"""Load element numbers, coordinates and atomic charges."""
atnums = np.empty(natoms)
atcoords = np.empty((natoms, 3))
Expand All @@ -112,7 +111,7 @@ def _load_helper_atoms(
return atnums, atcoords, atchgs, attypes


def _load_helper_bonds(lit: LineIterator, nbonds: int) -> tuple[np.ndarray]:
def _load_helper_bonds(lit: LineIterator, nbonds: int) -> NDArray:
"""Load bond information.
Each line in a bond definition has the following structure
Expand Down
Loading

0 comments on commit 9f7e987

Please sign in to comment.