Skip to content

Commit

Permalink
reuse proc_val in check_params, thanks @yantar92
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Oct 27, 2024
1 parent c365099 commit 43113fd
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 79 deletions.
149 changes: 72 additions & 77 deletions src/pymatgen/io/vasp/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,69 +723,9 @@ class Incar(UserDict, MSONable):
- Keys are stored in uppercase to allow case-insensitive access (set, get, del, update, setdefault).
- String values are capitalized by default, except for keys specified
in the `lower_str_keys/as_is_str_keys` class variables.
in the `lower_str_keys/as_is_str_keys` in `proc_val` method.
"""

list_keys: ClassVar[tuple[str, ...]] = (
"LDAUU",
"LDAUL",
"LDAUJ",
"MAGMOM",
"DIPOL",
"LANGEVIN_GAMMA",
"QUAD_EFG",
"EINT",
)
bool_keys: ClassVar[tuple[str, ...]] = (
"LDAU",
"LWAVE",
"LSCALU",
"LCHARG",
"LPLANE",
"LUSE_VDW",
"LHFCALC",
"ADDGRID",
"LSORBIT",
"LNONCOLLINEAR",
)
float_keys: ClassVar[tuple[str, ...]] = (
"EDIFF",
"SIGMA",
"TIME",
"ENCUTFOCK",
"HFSCREEN",
"POTIM",
"EDIFFG",
"AGGAC",
"PARAM1",
"PARAM2",
"ENCUT",
)
int_keys: ClassVar[tuple[str, ...]] = (
"NSW",
"NBANDS",
"NELMIN",
"ISIF",
"IBRION",
"ISPIN",
"ISTART",
"ICHARG",
"NELM",
"ISMEAR",
"NPAR",
"LDAUPRINT",
"LMAXMIX",
"NSIM",
"NKRED",
"NUPDOWN",
"ISPIND",
"LDAUTYPE",
"IVDW",
)
lower_str_keys: ClassVar[tuple[str, ...]] = ("ML_MODE",)
# String keywords to read "as is" (no case transformation, only stripped)
as_is_str_keys: ClassVar[tuple[str, ...]] = ("SYSTEM",)

def __init__(self, params: dict[str, Any] | None = None) -> None:
"""
Clean up params and create an Incar object.
Expand Down Expand Up @@ -965,15 +905,74 @@ def from_str(cls, string: str) -> Self:
params[key] = cls.proc_val(key, val)
return cls(params)

@classmethod
def proc_val(cls, key: str, val: str) -> list | bool | float | int | str:
@staticmethod
def proc_val(key: str, val: str) -> list | bool | float | int | str:
"""Helper method to convert INCAR parameters to proper types
like ints, floats, lists, etc.
Args:
key (str): INCAR parameter key.
val (str): Value of INCAR parameter.
"""
list_keys = (
"LDAUU",
"LDAUL",
"LDAUJ",
"MAGMOM",
"DIPOL",
"LANGEVIN_GAMMA",
"QUAD_EFG",
"EINT",
)
bool_keys = (
"LDAU",
"LWAVE",
"LSCALU",
"LCHARG",
"LPLANE",
"LUSE_VDW",
"LHFCALC",
"ADDGRID",
"LSORBIT",
"LNONCOLLINEAR",
)
float_keys = (
"EDIFF",
"SIGMA",
"TIME",
"ENCUTFOCK",
"HFSCREEN",
"POTIM",
"EDIFFG",
"AGGAC",
"PARAM1",
"PARAM2",
"ENCUT",
)
int_keys = (
"NSW",
"NBANDS",
"NELMIN",
"ISIF",
"IBRION",
"ISPIN",
"ISTART",
"ICHARG",
"NELM",
"ISMEAR",
"NPAR",
"LDAUPRINT",
"LMAXMIX",
"NSIM",
"NKRED",
"NUPDOWN",
"ISPIND",
"LDAUTYPE",
"IVDW",
)
lower_str_keys = ("ML_MODE",)
# String keywords to read "as is" (no case transformation, only stripped)
as_is_str_keys = ("SYSTEM",)

def smart_int_or_float(num_str: str) -> float:
"""Determine whether a string represents an integer or a float."""
Expand All @@ -982,7 +981,7 @@ def smart_int_or_float(num_str: str) -> float:
return int(num_str)

try:
if key in cls.list_keys:
if key in list_keys:
output = []
tokens = re.findall(r"(-?\d+\.?\d*)\*?(-?\d+\.?\d*)?\*?(-?\d+\.?\d*)?", val)
for tok in tokens:
Expand All @@ -994,22 +993,22 @@ def smart_int_or_float(num_str: str) -> float:
output.append(smart_int_or_float(tok[0]))
return output

if key in cls.bool_keys:
if key in bool_keys:
if match := re.match(r"^\.?([T|F|t|f])[A-Za-z]*\.?", val):
return match[1].lower() == "t"

raise ValueError(f"{key} should be a boolean type!")

if key in cls.float_keys:
if key in float_keys:
return float(re.search(r"^-?\d*\.?\d*[e|E]?-?\d*", val)[0]) # type: ignore[index]

if key in cls.int_keys:
if key in int_keys:
return int(re.match(r"^-?[0-9]+", val)[0]) # type: ignore[index]

if key in cls.lower_str_keys:
if key in lower_str_keys:
return val.strip().lower()

if key in cls.as_is_str_keys:
if key in as_is_str_keys:
return val.strip()

except ValueError:
Expand Down Expand Up @@ -1094,15 +1093,11 @@ def check_params(self) -> None:
# meaning there is recording for corresponding value
if allowed_values is not None:
# Note: param_type could be a Union type, e.g. "str | bool"
processed_values = []
for val in allowed_values:
if isinstance(val, str) and tag not in self.as_is_str_keys:
processed_item = val.lower() if tag in self.lower_str_keys else val.capitalize()
else:
processed_item = val
processed_values.append(processed_item)
allowed_values = [
self.proc_val(tag, item) if isinstance(item, str) else item for item in allowed_values
]

if val not in processed_values:
if val not in allowed_values:
warnings.warn(
f"{tag}: Cannot find {val} in the list of values",
BadIncarWarning,
Expand Down
4 changes: 2 additions & 2 deletions tests/io/vasp/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,8 +895,8 @@ def test_check_params(self):
"LREAL": True, # special case: Union type
"NBAND": 250, # typo in tag
"METAGGA": "SCAM", # typo in value
"EDIFF": 5 + 1j, # value should be a float
"ISIF": 9, # value out of range
"EDIFF": 5 + 1j, # value should be float
"ISIF": 9, # value not unknown
"LASPH": 5, # value should be bool
"PHON_TLIST": "is_a_str", # value should be a list
}
Expand Down

0 comments on commit 43113fd

Please sign in to comment.