Skip to content

Commit

Permalink
Turn all NamedTuple classes into attr classes
Browse files Browse the repository at this point in the history
This introduces additional attribute validation, which fixes some
silly bugs.
  • Loading branch information
tovrstra committed Aug 27, 2020
1 parent 8e33fd8 commit db297ac
Show file tree
Hide file tree
Showing 17 changed files with 497 additions and 100 deletions.
35 changes: 35 additions & 0 deletions CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,42 @@ to avoid duplicate efforts.
results in minor corrections at worst. We'll do our best to avoid larger
problems in step 1.


Notes on attrs
--------------

IOData uses the `attrs`_ library, not to be confused with the `attr`_ library,
for classes representing data loaded from files: ``IOData``, ``MolecularBasis``,
``Shell``, ``MolecularOrbitals`` and ``Cube``. This enables basic attribute
validation, which eliminates potentially silly bugs. The following two tricks
might be convenient with working with these classes:

- The data can be turned into plain Python data types with the ``attr.asdict``
function. Make sure you add the ``retain_collection_types=True`` option, to
avoid the following issue: https://github.com/python-attrs/attrs/issues/646
For example.

.. code-block:: python
from iodata import load_one
import attr
iodata = load_one("example.xyz")
fields = attr.asdict(iodata, retain_collection_types=True)
- A shallow copy with a few modified attributes can be created with the evolve
method, which is a wrapper for ``attr.evolve``:

.. code-block:: python
from iodata import load_one
import attr
iodata1 = load_one("example.xyz")
iodata2 = iodata1.evolve(title="another title")
.. _Bash: https://en.wikipedia.org/wiki/Bash_(Unix_shell)
.. _Python: https://en.wikipedia.org/wiki/Python_(programming_language)
.. _type hinting: https://docs.python.org/3/library/typing.html
.. _PEP 0563: https://www.python.org/dev/peps/pep-0563/
.. _attrs: https://www.attrs.org/en/stable/
.. _attr: https://github.com/denis-ryzhkov/attr
129 changes: 129 additions & 0 deletions iodata/attrutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""Utilities for building attr classes."""


import numpy as np


__all__ = ["convert_array_to", "validate_shape"]


def convert_array_to(dtype):
"""Return a function to convert arrays to the given type."""
def converter(array):
if array is None:
return None
return np.array(array, copy=False, dtype=dtype)
return converter


# pylint: disable=too-many-branches
def validate_shape(*shape_requirements: tuple):
"""Return a validator for the shape of an array or the length of an iterable.
Parameters
----------
shape_requirements
Specifications for the required shape. Every item of the tuple describes
the required size of the corresponding axis of an array. Also the
number of items should match the dimensionality of the array. When the
validator is used for general iterables, this tuple should contain just
one element. Possible values for each item are explained in the "Notes"
section below.
Returns
-------
validator
A validator function for the attr library.
Notes
-----
Every element of ``shape_requirements`` defines the expected size of an
array along the corresponding axis. An item in this tuple at position (or
index) ``i`` can be one of the following:
1. An integer, which is taken as the expected size along axis ``i``.
2. None. In this case, the size of the array along axis ``i`` is not
checked.
3. A string, which should be the name of another integer attribute with
the expected size along axis ``i``. The other attribute is always an
attribute of the same object as the attribute being checked.
4. A 2-tuple containing a name and an integer. In this case, the name refers
to another attribute which is an array or an iterable. When the integer
is 0, just the length of the other attribute is used. When the integer is
non-zero, the other attribute must be an array and the integer selects an
axis. The size of the other array along the selected axis is then used as
the expected size of the array being checked along axis ``i``.
"""
def validator(obj, attribute, value):
# Build the expected shape, with the rules from the docstring.
expected_shape = []
for item in shape_requirements:
if isinstance(item, int) or item is None:
expected_shape.append(item)
elif isinstance(item, str):
expected_shape.append(getattr(obj, item))
elif isinstance(item, tuple) and len(item) == 2:
other_name, other_axis = item
other = getattr(obj, other_name)
if other is None:
raise TypeError(
"Other attribute '{}' is not set.".format(other_name)
)
if other_axis == 0:
expected_shape.append(len(other))
else:
if other_axis >= other.ndim or other_axis < 0:
raise TypeError(
"Cannot get length along axis "
"{} of attribute {} with ndim {}.".format(
other_axis, other_name, other.ndim
)
)
expected_shape.append(other.shape[other_axis])
else:
raise ValueError(f"Cannot interpret item in shape_requirements: {item}")
expected_shape = tuple(expected_shape)
# Get the actual shape
if isinstance(value, np.ndarray):
observed_shape = value.shape
else:
observed_shape = (len(value),)
# Compare
match = True
if len(expected_shape) != len(observed_shape):
match = False
if match:
for es, os in zip(expected_shape, observed_shape):
if es is None:
continue
if es != os:
match = False
break
# Raise TypeError if needed.
if not match:
raise TypeError(
"Expecting shape {} for attribute {}, got {}".format(
expected_shape, attribute.name, observed_shape
)
)

return validator
32 changes: 23 additions & 9 deletions iodata/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@

from functools import wraps
from numbers import Integral
from typing import List, Dict, NamedTuple, Tuple, Union
from typing import List, Dict, Tuple, Union

import attr
import numpy as np

from .attrutils import validate_shape


__all__ = ['angmom_sti', 'angmom_its', 'Shell', 'MolecularBasis',
'convert_convention_shell', 'convert_conventions',
'iter_cart_alphabet', 'HORTON2_CONVENTIONS', 'PSI4_CONVENTIONS']
Expand Down Expand Up @@ -81,7 +85,8 @@ def angmom_its(angmom: Union[int, List[int]]) -> Union[str, List[str]]:
return ANGMOM_CHARS[angmom]


class Shell(NamedTuple):
@attr.s(auto_attribs=True, slots=True)
class Shell:
"""A shell in a molecular basis representing (generalized) contractions with the same exponents.
Attributes
Expand All @@ -107,10 +112,10 @@ class Shell(NamedTuple):
"""

icenter: int
angmoms: List[int]
kinds: List[str]
exponents: np.ndarray
coeffs: np.ndarray
angmoms: List[int] = attr.ib(validator=validate_shape(("coeffs", 1)))
kinds: List[str] = attr.ib(validator=validate_shape(("coeffs", 1)))
exponents: np.ndarray = attr.ib(validator=validate_shape(("coeffs", 0)))
coeffs: np.ndarray = attr.ib(validator=validate_shape(("exponents", 0), ("kinds", 0)))

@property
def nbasis(self) -> int: # noqa: D401
Expand All @@ -135,8 +140,13 @@ def ncon(self) -> int: # noqa: D401
"""Number of contractions. This is usually 1; e.g., it would be 2 for an SP shell."""
return len(self.angmoms)

def evolve(self, **changes):
"""Create a copy with update attributes given in ``changes``."""
return attr.evolve(self, **changes)

class MolecularBasis(NamedTuple):

@attr.s(auto_attribs=True, slots=True)
class MolecularBasis:
"""A complete molecular orbital or density basis set.
Attributes
Expand Down Expand Up @@ -184,7 +194,7 @@ class MolecularBasis(NamedTuple):
"""

shells: tuple
shells: List[Shell]
conventions: Dict[str, str]
primitive_normalization: str

Expand All @@ -193,6 +203,10 @@ def nbasis(self) -> int: # noqa: D401
"""Number of basis functions."""
return sum(shell.nbasis for shell in self.shells)

def evolve(self, **changes):
"""Create a copy with update attributes given in ``changes``."""
return attr.evolve(self, **changes)

def get_segmented(self):
"""Unroll generalized contractions."""
shells = []
Expand All @@ -201,7 +215,7 @@ def get_segmented(self):
shells.append(Shell(shell.icenter, [angmom], [kind],
shell.exponents, coeffs.reshape(-1, 1)))
# pylint: disable=no-member
return self._replace(shells=shells)
return self.evolve(shells=shells)


def convert_convention_shell(conv1: List[str], conv2: List[str], reverse=False) \
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/chgcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _load_vasp_grid(lit: LineIterator) -> dict:
cube_data[i0, i1, i2] = float(words.pop(0))

cube = Cube(origin=np.zeros(3), axes=cellvecs / shape.reshape(-1, 1),
shape=shape, data=cube_data)
data=cube_data)

return {
'title': title,
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/cp2klog.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _read_cp2k_uncontracted_obasis(lit: LineIterator) -> MolecularBasis:
# read the exponent
exponent = float(words[-1])
exponents.append(exponent)
coeffs.append([1.0 / _get_cp2k_norm_corrections(angmom, exponent)])
coeffs.append(1.0 / _get_cp2k_norm_corrections(angmom, exponent))
line = next(lit)
# Build the shell
kind = 'c' if angmom < 2 else 'p'
Expand Down
1 change: 1 addition & 0 deletions iodata/formats/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def load_one(lit: LineIterator) -> dict:
"""Do not edit this docstring. It will be overwritten."""
title, atcoords, atnums, cellvecs, cube, atcorenums = _read_cube_header(lit)
_read_cube_data(lit, cube)
del cube["shape"]
return {
'title': title,
'atcoords': atcoords,
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/molden.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def _fix_obasis_normalize_contractions(obasis: MolecularBasis) -> MolecularBasis
fixed_shells = []
for shell in obasis.shells:
shell_obasis = MolecularBasis(
[shell._replace(icenter=0)],
[shell.evolve(icenter=0)],
obasis.conventions,
obasis.primitive_normalization
)
Expand Down
49 changes: 11 additions & 38 deletions iodata/iodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import attr
import numpy as np

from .attrutils import convert_array_to, validate_shape
from .basis import MolecularBasis
from .orbitals import MolecularOrbitals
from .utils import Cube
Expand All @@ -30,34 +31,6 @@
__all__ = ['IOData']


def convert_array_to(dtype):
"""Return a function to convert arrays to the given type."""
def converter(array):
if array is None:
return None
return np.array(array, copy=False, dtype=dtype)
return converter


def validate_shape(*shape):
"""Return a function to validate the shape of an array."""
def validator(obj, attrname, value):
if value is None:
return
myshape = tuple([obj.natom if size == 'natom' else size for size in shape])
if len(myshape) != len(value.shape):
raise TypeError('Expect ndim {} for attribute {}, got {}'.format(
len(myshape), attrname, len(value.shape)))
for axis, size in enumerate(myshape):
if size is None:
continue
if size != value.shape[axis]:
raise TypeError(
'Expect size {} for axis {} of attribute {}, got {}'.format(
size, axis, attrname, value.shape[axis]))
return validator


# pylint: disable=too-many-instance-attributes
@attr.s(auto_attribs=True, slots=True)
class IOData:
Expand Down Expand Up @@ -198,40 +171,40 @@ class IOData:
atcharges: dict = {}
atcoords: np.ndarray = attr.ib(
default=None, converter=convert_array_to(float),
validator=validate_shape('natom', 3))
validator=attr.validators.optional(validate_shape('natom', 3)))
_atcorenums: np.ndarray = attr.ib(
default=None, converter=convert_array_to(float),
validator=validate_shape('natom'))
validator=attr.validators.optional(validate_shape('natom')))
atffparams: dict = {}
atfrozen: np.ndarray = attr.ib(
default=None, converter=convert_array_to(bool),
validator=validate_shape('natom'))
validator=attr.validators.optional(validate_shape('natom')))
atgradient: np.ndarray = attr.ib(
default=None, converter=convert_array_to(float),
validator=validate_shape('natom', 3))
validator=attr.validators.optional(validate_shape('natom', 3)))
athessian: np.ndarray = attr.ib(
default=None, converter=convert_array_to(float),
validator=validate_shape(None, None))
validator=attr.validators.optional(validate_shape(None, None)))
atmasses: np.ndarray = attr.ib(
default=None, converter=convert_array_to(float),
validator=validate_shape('natom'))
validator=attr.validators.optional(validate_shape('natom')))
atnums: np.ndarray = attr.ib(
default=None, converter=convert_array_to(int),
validator=validate_shape('natom'))
validator=attr.validators.optional(validate_shape('natom')))
basisdef: str = None
bonds: np.ndarray = attr.ib(
default=None, converter=convert_array_to(int),
validator=validate_shape(None, 3))
validator=attr.validators.optional(validate_shape(None, 3)))
cellvecs: np.ndarray = attr.ib(
default=None, converter=convert_array_to(float),
validator=validate_shape(None, 3))
validator=attr.validators.optional(validate_shape(None, 3)))
_charge: float = None
core_energy: float = None
cube: Cube = None
energy: float = None
extcharges: np.ndarray = attr.ib(
default=None, converter=convert_array_to(float),
validator=validate_shape(None, 4))
validator=attr.validators.optional(validate_shape(None, 4)))
extra: dict = {}
g_rot: float = None
lot: str = None
Expand Down
Loading

0 comments on commit db297ac

Please sign in to comment.