Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add category property to OutputVariableDef #3228

Merged
merged 14 commits into from
Feb 6, 2024
27 changes: 16 additions & 11 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,23 @@
-------
int
The new category of the variable definition.

Raises
------
ValueError
If the operation has been applied to the variable definition,
and exceed the maximum limitation.
"""
if op == OutputVariableOperation.REDU or op == OutputVariableOperation.DERV_C:
if check_operation_applied(var_def, op):
raise ValueError(f"operation {op} has been applied")
elif op == OutputVariableOperation.DERV_R:
if check_operation_applied(var_def, OutputVariableOperation.DERV_R):
op = OutputVariableOperation.SEC_DERV_R
if check_operation_applied(var_def, OutputVariableOperation.SEC_DERV_R):
raise ValueError(f"operation {op} has been applied twice")
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"operation {op} not supported")

Check warning on line 326 in deepmd/dpmodel/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/output_def.py#L326

Added line #L326 was not covered by tests
return var_def.category | op.value


Expand All @@ -328,17 +344,6 @@
bool
True if the operation has been applied, False otherwise.
"""
if op in (OutputVariableOperation.DERV_REDU, OutputVariableOperation.DERV_C):
assert not check_operation_applied(var_def, op)
elif op == OutputVariableOperation.DERV_R:
if check_operation_applied(var_def, OutputVariableOperation.DERV_R):
op = OutputVariableOperation.SEC_DERV_R
else:
assert not check_operation_applied(
var_def, OutputVariableOperation.SEC_DERV_R
)
else:
raise ValueError(f"operation {op} not supported")
return var_def.category & op.value == op.value


Expand Down
43 changes: 43 additions & 0 deletions source/tests/common/dpmodel/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from deepmd.dpmodel.output_def import (
OutputVariableCategory,
OutputVariableOperation,
apply_operation,
check_var,
)

Expand Down Expand Up @@ -159,6 +160,48 @@ def test_model_output_def(self):
OutputVariableOperation.DERV_C,
)

# apply_operation
self.assertEqual(
apply_operation(md["energy"], OutputVariableOperation.REDU),
md["energy_redu"].category,
)
self.assertEqual(
apply_operation(md["energy"], OutputVariableOperation.DERV_R),
md["energy_derv_r"].category,
)
self.assertEqual(
apply_operation(md["energy"], OutputVariableOperation.DERV_C),
md["energy_derv_c"].category,
)
self.assertEqual(
apply_operation(md["energy_derv_c"], OutputVariableOperation.REDU),
md["energy_derv_c_redu"].category,
)
# raise ValueError
with self.assertRaises(ValueError):
apply_operation(md["energy_redu"], OutputVariableOperation.REDU)
with self.assertRaises(ValueError):
apply_operation(md["energy_derv_c"], OutputVariableOperation.DERV_C)
with self.assertRaises(ValueError):
apply_operation(md["energy_derv_c_redu"], OutputVariableOperation.REDU)
# hession
hession_cat = apply_operation(
md["energy_derv_r"], OutputVariableOperation.DERV_R
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
)
self.assertEqual(
hession_cat & OutputVariableOperation.DERV_R, OutputVariableOperation.DERV_R
)
self.assertEqual(
hession_cat & OutputVariableOperation.SEC_DERV_R,
OutputVariableOperation.SEC_DERV_R,
)
self.assertEqual(hession_cat, OutputVariableCategory.DERV_R_DERV_R)
hession_vardef = OutputVariableDef(
"energy_derv_r_derv_r", [1], False, False, category=hession_cat
)
with self.assertRaises(ValueError):
apply_operation(hession_vardef, OutputVariableOperation.DERV_R)

def test_raise_no_redu_deriv(self):
with self.assertRaises(ValueError) as context:
(OutputVariableDef("energy", [1], False, True),)
Expand Down