Skip to content

Commit

Permalink
check_params consider lower str exception
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Oct 27, 2024
1 parent b53d52c commit c365099
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 71 deletions.
11 changes: 11 additions & 0 deletions src/pymatgen/io/vasp/incar_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,17 @@
"ML_FF_WTSIF": {
"type": "float"
},
"ML_MODE": {
"type": "str",
"values": [
"train",
"select",
"refit",
"refitbayesian",
"run",
"none"
]
},
"NBANDS": {
"type": "int"
},
Expand Down
148 changes: 77 additions & 71 deletions src/pymatgen/io/vasp/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,69 @@ class Incar(UserDict, MSONable):
- Keys are stored in uppercase to allow case-insensitive access (set, get, del, update, setdefault).
- String values are capitalized by default, except for keys specified
in the `lower_str_keys/as_is_str_keys` of the `proc_val` method.
in the `lower_str_keys/as_is_str_keys` class variables.
"""

list_keys: ClassVar[tuple[str, ...]] = (
"LDAUU",
"LDAUL",
"LDAUJ",
"MAGMOM",
"DIPOL",
"LANGEVIN_GAMMA",
"QUAD_EFG",
"EINT",
)
bool_keys: ClassVar[tuple[str, ...]] = (
"LDAU",
"LWAVE",
"LSCALU",
"LCHARG",
"LPLANE",
"LUSE_VDW",
"LHFCALC",
"ADDGRID",
"LSORBIT",
"LNONCOLLINEAR",
)
float_keys: ClassVar[tuple[str, ...]] = (
"EDIFF",
"SIGMA",
"TIME",
"ENCUTFOCK",
"HFSCREEN",
"POTIM",
"EDIFFG",
"AGGAC",
"PARAM1",
"PARAM2",
"ENCUT",
)
int_keys: ClassVar[tuple[str, ...]] = (
"NSW",
"NBANDS",
"NELMIN",
"ISIF",
"IBRION",
"ISPIN",
"ISTART",
"ICHARG",
"NELM",
"ISMEAR",
"NPAR",
"LDAUPRINT",
"LMAXMIX",
"NSIM",
"NKRED",
"NUPDOWN",
"ISPIND",
"LDAUTYPE",
"IVDW",
)
lower_str_keys: ClassVar[tuple[str, ...]] = ("ML_MODE",)
# String keywords to read "as is" (no case transformation, only stripped)
as_is_str_keys: ClassVar[tuple[str, ...]] = ("SYSTEM",)

def __init__(self, params: dict[str, Any] | None = None) -> None:
"""
Clean up params and create an Incar object.
Expand Down Expand Up @@ -905,74 +965,15 @@ def from_str(cls, string: str) -> Self:
params[key] = cls.proc_val(key, val)
return cls(params)

@staticmethod
def proc_val(key: str, val: str) -> list | bool | float | int | str:
@classmethod
def proc_val(cls, key: str, val: str) -> list | bool | float | int | str:
"""Helper method to convert INCAR parameters to proper types
like ints, floats, lists, etc.
Args:
key (str): INCAR parameter key.
val (str): Value of INCAR parameter.
"""
list_keys = (
"LDAUU",
"LDAUL",
"LDAUJ",
"MAGMOM",
"DIPOL",
"LANGEVIN_GAMMA",
"QUAD_EFG",
"EINT",
)
bool_keys = (
"LDAU",
"LWAVE",
"LSCALU",
"LCHARG",
"LPLANE",
"LUSE_VDW",
"LHFCALC",
"ADDGRID",
"LSORBIT",
"LNONCOLLINEAR",
)
float_keys = (
"EDIFF",
"SIGMA",
"TIME",
"ENCUTFOCK",
"HFSCREEN",
"POTIM",
"EDIFFG",
"AGGAC",
"PARAM1",
"PARAM2",
"ENCUT",
)
int_keys = (
"NSW",
"NBANDS",
"NELMIN",
"ISIF",
"IBRION",
"ISPIN",
"ISTART",
"ICHARG",
"NELM",
"ISMEAR",
"NPAR",
"LDAUPRINT",
"LMAXMIX",
"NSIM",
"NKRED",
"NUPDOWN",
"ISPIND",
"LDAUTYPE",
"IVDW",
)
lower_str_keys = ("ML_MODE",)
# String keywords to read "as is" (no case transformation, only stripped)
as_is_str_keys = ("SYSTEM",)

def smart_int_or_float(num_str: str) -> float:
"""Determine whether a string represents an integer or a float."""
Expand All @@ -981,7 +982,7 @@ def smart_int_or_float(num_str: str) -> float:
return int(num_str)

try:
if key in list_keys:
if key in cls.list_keys:
output = []
tokens = re.findall(r"(-?\d+\.?\d*)\*?(-?\d+\.?\d*)?\*?(-?\d+\.?\d*)?", val)
for tok in tokens:
Expand All @@ -993,22 +994,22 @@ def smart_int_or_float(num_str: str) -> float:
output.append(smart_int_or_float(tok[0]))
return output

if key in bool_keys:
if key in cls.bool_keys:
if match := re.match(r"^\.?([T|F|t|f])[A-Za-z]*\.?", val):
return match[1].lower() == "t"

raise ValueError(f"{key} should be a boolean type!")

if key in float_keys:
if key in cls.float_keys:
return float(re.search(r"^-?\d*\.?\d*[e|E]?-?\d*", val)[0]) # type: ignore[index]

if key in int_keys:
if key in cls.int_keys:
return int(re.match(r"^-?[0-9]+", val)[0]) # type: ignore[index]

if key in lower_str_keys:
if key in cls.lower_str_keys:
return val.strip().lower()

if key in as_is_str_keys:
if key in cls.as_is_str_keys:
return val.strip()

except ValueError:
Expand Down Expand Up @@ -1093,10 +1094,15 @@ def check_params(self) -> None:
# meaning there is recording for corresponding value
if allowed_values is not None:
# Note: param_type could be a Union type, e.g. "str | bool"
if "str" in param_type:
allowed_values = [item.capitalize() if isinstance(item, str) else item for item in allowed_values]
processed_values = []
for val in allowed_values:
if isinstance(val, str) and tag not in self.as_is_str_keys:
processed_item = val.lower() if tag in self.lower_str_keys else val.capitalize()
else:

This comment has been minimized.

Copy link
@yantar92

yantar92 Oct 27, 2024

Contributor

Why not simply calling proc_val on these? This would avoid code duplication.

This comment has been minimized.

Copy link
@DanielYang59

DanielYang59 Oct 27, 2024

Author Contributor

That is a MUCH MUCH MUCH smarter way than mine, sorry I got confused by working at several PRs at the same time

processed_item = val
processed_values.append(processed_item)

if val not in allowed_values:
if val not in processed_values:
warnings.warn(
f"{tag}: Cannot find {val} in the list of values",
BadIncarWarning,
Expand Down
2 changes: 2 additions & 0 deletions tests/io/vasp/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,8 @@ def test_check_params(self):
"AMIN": 0.01,
"ICHARG": 1,
"MAGMOM": [1, 2, 4, 5],
"ML_MODE": "RUN", # lower case string
"SYSTEM": "Hello world", # as is string
"ENCUT": 500, # make sure float key is casted
"GGA": "PS", # test string case insensitivity
"LREAL": True, # special case: Union type
Expand Down

0 comments on commit c365099

Please sign in to comment.