Skip to content

Commit

Permalink
[BugFix] Subclass Construction Locpot<:VolumetricData (#3639)
Browse files Browse the repository at this point in the history
* kwargs

* test

* types + doc str fixes

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
jmmshn and janosh authored Feb 21, 2024
1 parent ae36a84 commit 3f89175
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 19 deletions.
14 changes: 8 additions & 6 deletions pymatgen/io/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,23 @@ class VolumetricData(MSONable):
ngridpts (int): Total number of grid points in volumetric data.
"""

def __init__(self, structure: Structure, data, distance_matrix=None, data_aug=None):
def __init__(
self, structure: Structure, data: np.ndarray, distance_matrix: np.ndarray = None, data_aug: np.ndarray = None
) -> None:
"""
Typically, this constructor is not used directly and the static
from_file constructor is used. This constructor is designed to allow
summation and other operations between VolumetricData objects.
Args:
structure: Structure associated with the volumetric data
data: Actual volumetric data. If the data is provided as in list format,
structure (Structure): associated with the volumetric data
data (np.array): Actual volumetric data. If the data is provided as in list format,
it will be converted into an np.array automatically
data_aug: Any extra information associated with volumetric data
(typically augmentation charges)
distance_matrix: A pre-computed distance matrix if available.
distance_matrix (np.array): A pre-computed distance matrix if available.
Useful so pass distance_matrices between sums,
short-circuiting an otherwise expensive operation.
data_aug (np.array): Any extra information associated with volumetric data
(typically augmentation charges)
"""
self.structure = structure
self.is_spin_polarized = len(data) >= 2
Expand Down
26 changes: 13 additions & 13 deletions pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3251,7 +3251,7 @@ class VolumetricData(BaseVolumetricData):
"""

@staticmethod
def parse_file(filename):
def parse_file(filename: str) -> tuple[Poscar, dict, dict]:
"""
Convenience method to parse a generic volumetric data file in the vasp
like format. Used by subclasses for parsing file.
Expand All @@ -3260,17 +3260,17 @@ def parse_file(filename):
filename (str): Path of file to parse
Returns:
(poscar, data)
tuple[Poscar, dict, dict]: Poscar object, data dict, data_aug dict
"""

poscar_read = False
poscar_string = []
dataset = []
all_dataset = []
poscar_string: list[str] = []
dataset: np.ndarray = np.zeros((1, 1, 1))
all_dataset: list[np.ndarray] = []
# for holding any strings in input that are not Poscar
# or VolumetricData (typically augmentation charges)
all_dataset_aug = {}
dim = dimline = None
all_dataset_aug: dict[int, list[str]] = {}
dim: list[int] = []
dimline = ""
read_dataset = False
ngrid_pts = 0
data_count = 0
Expand Down Expand Up @@ -3354,7 +3354,7 @@ def parse_file(filename):
else:
data = {"total": all_dataset[0]}
data_aug = {"total": all_dataset_aug.get(0)}
return poscar, data, data_aug
return poscar, data, data_aug # type: ignore[return-value]

def write_file(self, file_name: str | Path, vasp4_compatible: bool = False) -> None:
"""
Expand Down Expand Up @@ -3434,13 +3434,13 @@ def write_spin(data_type):
class Locpot(VolumetricData):
"""Simple object for reading a LOCPOT file."""

def __init__(self, poscar, data):
def __init__(self, poscar: Poscar, data: np.ndarray, **kwargs) -> None:
"""
Args:
poscar (Poscar): Poscar object containing structure.
data: Actual data.
data (np.ndarray): Actual data.
"""
super().__init__(poscar.structure, data)
super().__init__(poscar.structure, data, **kwargs)
self.name = poscar.comment

@classmethod
Expand All @@ -3453,7 +3453,7 @@ def from_file(cls, filename, **kwargs):
Returns:
Locpot
"""
(poscar, data, _data_aug) = VolumetricData.parse_file(filename)
poscar, data, _data_aug = VolumetricData.parse_file(filename)
return cls(poscar, data, **kwargs)


Expand Down
5 changes: 5 additions & 0 deletions tests/io/vasp/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,11 @@ def test_init(self):
assert locpot.get_axis_grid(1)[-1] == approx(2.87629, abs=1e-2)
assert locpot.get_axis_grid(2)[-1] == approx(2.87629, abs=1e-2)

# make sure locpot constructor works with data_aug=None
poscar, data, _data_aug = Locpot.parse_file(filepath)
l2 = Locpot(poscar=poscar, data=data, data_aug=None)
assert l2.data_aug == {}


class TestChgcar(PymatgenTest):
@classmethod
Expand Down

0 comments on commit 3f89175

Please sign in to comment.