-
Notifications
You must be signed in to change notification settings - Fork 868
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
check_params consider lower str exception
- Loading branch information
1 parent
b53d52c
commit c365099
Showing
3 changed files
with
90 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -723,9 +723,69 @@ class Incar(UserDict, MSONable): | |
- Keys are stored in uppercase to allow case-insensitive access (set, get, del, update, setdefault). | ||
- String values are capitalized by default, except for keys specified | ||
in the `lower_str_keys/as_is_str_keys` of the `proc_val` method. | ||
in the `lower_str_keys/as_is_str_keys` class variables. | ||
""" | ||
|
||
list_keys: ClassVar[tuple[str, ...]] = ( | ||
"LDAUU", | ||
"LDAUL", | ||
"LDAUJ", | ||
"MAGMOM", | ||
"DIPOL", | ||
"LANGEVIN_GAMMA", | ||
"QUAD_EFG", | ||
"EINT", | ||
) | ||
bool_keys: ClassVar[tuple[str, ...]] = ( | ||
"LDAU", | ||
"LWAVE", | ||
"LSCALU", | ||
"LCHARG", | ||
"LPLANE", | ||
"LUSE_VDW", | ||
"LHFCALC", | ||
"ADDGRID", | ||
"LSORBIT", | ||
"LNONCOLLINEAR", | ||
) | ||
float_keys: ClassVar[tuple[str, ...]] = ( | ||
"EDIFF", | ||
"SIGMA", | ||
"TIME", | ||
"ENCUTFOCK", | ||
"HFSCREEN", | ||
"POTIM", | ||
"EDIFFG", | ||
"AGGAC", | ||
"PARAM1", | ||
"PARAM2", | ||
"ENCUT", | ||
) | ||
int_keys: ClassVar[tuple[str, ...]] = ( | ||
"NSW", | ||
"NBANDS", | ||
"NELMIN", | ||
"ISIF", | ||
"IBRION", | ||
"ISPIN", | ||
"ISTART", | ||
"ICHARG", | ||
"NELM", | ||
"ISMEAR", | ||
"NPAR", | ||
"LDAUPRINT", | ||
"LMAXMIX", | ||
"NSIM", | ||
"NKRED", | ||
"NUPDOWN", | ||
"ISPIND", | ||
"LDAUTYPE", | ||
"IVDW", | ||
) | ||
lower_str_keys: ClassVar[tuple[str, ...]] = ("ML_MODE",) | ||
# String keywords to read "as is" (no case transformation, only stripped) | ||
as_is_str_keys: ClassVar[tuple[str, ...]] = ("SYSTEM",) | ||
|
||
def __init__(self, params: dict[str, Any] | None = None) -> None: | ||
""" | ||
Clean up params and create an Incar object. | ||
|
@@ -905,74 +965,15 @@ def from_str(cls, string: str) -> Self: | |
params[key] = cls.proc_val(key, val) | ||
return cls(params) | ||
|
||
@staticmethod | ||
def proc_val(key: str, val: str) -> list | bool | float | int | str: | ||
@classmethod | ||
def proc_val(cls, key: str, val: str) -> list | bool | float | int | str: | ||
"""Helper method to convert INCAR parameters to proper types | ||
like ints, floats, lists, etc. | ||
Args: | ||
key (str): INCAR parameter key. | ||
val (str): Value of INCAR parameter. | ||
""" | ||
list_keys = ( | ||
"LDAUU", | ||
"LDAUL", | ||
"LDAUJ", | ||
"MAGMOM", | ||
"DIPOL", | ||
"LANGEVIN_GAMMA", | ||
"QUAD_EFG", | ||
"EINT", | ||
) | ||
bool_keys = ( | ||
"LDAU", | ||
"LWAVE", | ||
"LSCALU", | ||
"LCHARG", | ||
"LPLANE", | ||
"LUSE_VDW", | ||
"LHFCALC", | ||
"ADDGRID", | ||
"LSORBIT", | ||
"LNONCOLLINEAR", | ||
) | ||
float_keys = ( | ||
"EDIFF", | ||
"SIGMA", | ||
"TIME", | ||
"ENCUTFOCK", | ||
"HFSCREEN", | ||
"POTIM", | ||
"EDIFFG", | ||
"AGGAC", | ||
"PARAM1", | ||
"PARAM2", | ||
"ENCUT", | ||
) | ||
int_keys = ( | ||
"NSW", | ||
"NBANDS", | ||
"NELMIN", | ||
"ISIF", | ||
"IBRION", | ||
"ISPIN", | ||
"ISTART", | ||
"ICHARG", | ||
"NELM", | ||
"ISMEAR", | ||
"NPAR", | ||
"LDAUPRINT", | ||
"LMAXMIX", | ||
"NSIM", | ||
"NKRED", | ||
"NUPDOWN", | ||
"ISPIND", | ||
"LDAUTYPE", | ||
"IVDW", | ||
) | ||
lower_str_keys = ("ML_MODE",) | ||
# String keywords to read "as is" (no case transformation, only stripped) | ||
as_is_str_keys = ("SYSTEM",) | ||
|
||
def smart_int_or_float(num_str: str) -> float: | ||
"""Determine whether a string represents an integer or a float.""" | ||
|
@@ -981,7 +982,7 @@ def smart_int_or_float(num_str: str) -> float: | |
return int(num_str) | ||
|
||
try: | ||
if key in list_keys: | ||
if key in cls.list_keys: | ||
output = [] | ||
tokens = re.findall(r"(-?\d+\.?\d*)\*?(-?\d+\.?\d*)?\*?(-?\d+\.?\d*)?", val) | ||
for tok in tokens: | ||
|
@@ -993,22 +994,22 @@ def smart_int_or_float(num_str: str) -> float: | |
output.append(smart_int_or_float(tok[0])) | ||
return output | ||
|
||
if key in bool_keys: | ||
if key in cls.bool_keys: | ||
if match := re.match(r"^\.?([T|F|t|f])[A-Za-z]*\.?", val): | ||
return match[1].lower() == "t" | ||
|
||
raise ValueError(f"{key} should be a boolean type!") | ||
|
||
if key in float_keys: | ||
if key in cls.float_keys: | ||
return float(re.search(r"^-?\d*\.?\d*[e|E]?-?\d*", val)[0]) # type: ignore[index] | ||
|
||
if key in int_keys: | ||
if key in cls.int_keys: | ||
return int(re.match(r"^-?[0-9]+", val)[0]) # type: ignore[index] | ||
|
||
if key in lower_str_keys: | ||
if key in cls.lower_str_keys: | ||
return val.strip().lower() | ||
|
||
if key in as_is_str_keys: | ||
if key in cls.as_is_str_keys: | ||
return val.strip() | ||
|
||
except ValueError: | ||
|
@@ -1093,10 +1094,15 @@ def check_params(self) -> None: | |
# meaning there is recording for corresponding value | ||
if allowed_values is not None: | ||
# Note: param_type could be a Union type, e.g. "str | bool" | ||
if "str" in param_type: | ||
allowed_values = [item.capitalize() if isinstance(item, str) else item for item in allowed_values] | ||
processed_values = [] | ||
for val in allowed_values: | ||
if isinstance(val, str) and tag not in self.as_is_str_keys: | ||
processed_item = val.lower() if tag in self.lower_str_keys else val.capitalize() | ||
else: | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
DanielYang59
Author
Contributor
|
||
processed_item = val | ||
processed_values.append(processed_item) | ||
|
||
if val not in allowed_values: | ||
if val not in processed_values: | ||
warnings.warn( | ||
f"{tag}: Cannot find {val} in the list of values", | ||
BadIncarWarning, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Why not simply calling proc_val on these? This would avoid code duplication.