Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add category property to OutputVariableDef #3228

Merged
merged 14 commits into from
Feb 6, 2024
48 changes: 46 additions & 2 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
from enum import (
IntEnum,
)
from typing import (
Dict,
List,
Expand Down Expand Up @@ -107,6 +110,34 @@ def __call__(
return wrapper


class OutputVariableOperation(IntEnum):
"""Defines the operation of the output variable."""

NONE = 0
"""No operation."""
REDU = 1
"""Reduce the output variable."""
DERV_R = 2
"""Derivative w.r.t. coordinates."""
DERV_C = 4
"""Derivative w.r.t. cell."""


class OutputVariableCategory(IntEnum):
"""Defines the category of the output variable."""

OUT = OutputVariableOperation.NONE
"""Output variable. (e.g. atom energy)"""
REDU = OutputVariableOperation.REDU
"""Reduced output variable. (e.g. system energy)"""
DERV_R = OutputVariableOperation.DERV_R
"""Negative derivative w.r.t. coordinates. (e.g. force)"""
DERV_C = OutputVariableOperation.DERV_C
"""Atomic component of the virial, see PRB 104, 224202 (2021) """
DERV_C_REDU = OutputVariableOperation.DERV_C | OutputVariableOperation.REDU
"""Virial, the transposed negative gradient with cell tensor times cell tensor, see eq 40 JCP 159, 054801 (2023). """


class OutputVariableDef:
"""Defines the shape and other properties of the one output variable.

Expand All @@ -129,7 +160,8 @@ class OutputVariableDef:
If the variable is differentiated with respect to coordinates
of atoms and cell tensor (pbc case). Only reduciable variable
are differentiable.

category : int
The category of the output variable.
"""

def __init__(
Expand All @@ -139,6 +171,7 @@ def __init__(
reduciable: bool = False,
differentiable: bool = False,
atomic: bool = True,
category: int = OutputVariableCategory.OUT.value,
):
self.name = name
self.shape = list(shape)
Expand All @@ -149,6 +182,7 @@ def __init__(
raise ValueError("only reduciable variable are differentiable")
if self.reduciable and not self.atomic:
raise ValueError("only reduciable variable should be atomic")
self.category = category


class FittingOutputDef:
Expand Down Expand Up @@ -261,9 +295,15 @@ def do_reduce(
def_redu: Dict[str, OutputVariableDef] = {}
for kk, vv in def_outp_data.items():
if vv.reduciable:
assert vv.category & OutputVariableOperation.REDU.value == 0
rk = get_reduce_name(kk)
def_redu[rk] = OutputVariableDef(
rk, vv.shape, reduciable=False, differentiable=False, atomic=False
rk,
vv.shape,
reduciable=False,
differentiable=False,
atomic=False,
category=vv.category | OutputVariableOperation.REDU.value,
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
)
return def_redu

Expand All @@ -275,19 +315,23 @@ def do_derivative(
def_derv_c: Dict[str, OutputVariableDef] = {}
for kk, vv in def_outp_data.items():
if vv.differentiable:
assert vv.category & OutputVariableOperation.DERV_R.value == 0
assert vv.category & OutputVariableOperation.DERV_C.value == 0
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
rkr, rkc = get_deriv_name(kk)
def_derv_r[rkr] = OutputVariableDef(
rkr,
vv.shape + [3], # noqa: RUF005
reduciable=False,
differentiable=False,
atomic=True,
category=vv.category | OutputVariableOperation.DERV_R.value,
)
def_derv_c[rkc] = OutputVariableDef(
rkc,
vv.shape + [3, 3], # noqa: RUF005
reduciable=True,
differentiable=False,
atomic=True,
category=vv.category | OutputVariableOperation.DERV_C.value,
)
return def_derv_r, def_derv_c
55 changes: 55 additions & 0 deletions source/tests/common/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
model_check_output,
)
from deepmd.dpmodel.output_def import (
OutputVariableCategory,
OutputVariableOperation,
check_var,
)

Expand Down Expand Up @@ -103,6 +105,59 @@ def test_model_output_def(self):
self.assertEqual(md["energy_derv_r"].atomic, True)
self.assertEqual(md["energy_derv_c"].atomic, True)
self.assertEqual(md["energy_derv_c_redu"].atomic, False)
# category
self.assertEqual(md["energy"].category, OutputVariableCategory.OUT)
self.assertEqual(md["dos"].category, OutputVariableCategory.OUT)
self.assertEqual(md["foo"].category, OutputVariableCategory.OUT)
self.assertEqual(md["energy_redu"].category, OutputVariableCategory.REDU)
self.assertEqual(md["energy_derv_r"].category, OutputVariableCategory.DERV_R)
self.assertEqual(md["energy_derv_c"].category, OutputVariableCategory.DERV_C)
self.assertEqual(
md["energy_derv_c_redu"].category, OutputVariableCategory.DERV_C_REDU
)
# flag
self.assertEqual(md["energy"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["energy"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["energy"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(
md["energy_redu"].category & OutputVariableOperation.REDU,
OutputVariableOperation.REDU,
)
self.assertEqual(md["energy_redu"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["energy_redu"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["energy_derv_r"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(
md["energy_derv_r"].category & OutputVariableOperation.DERV_R,
OutputVariableOperation.DERV_R,
)
self.assertEqual(
md["energy_derv_r"].category & OutputVariableOperation.DERV_C, 0
)
self.assertEqual(md["energy_derv_c"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(
md["energy_derv_c"].category & OutputVariableOperation.DERV_R, 0
)
self.assertEqual(
md["energy_derv_c"].category & OutputVariableOperation.DERV_C,
OutputVariableOperation.DERV_C,
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.REDU,
OutputVariableOperation.REDU,
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.DERV_R, 0
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.DERV_C,
OutputVariableOperation.DERV_C,
)

def test_raise_no_redu_deriv(self):
with self.assertRaises(ValueError) as context:
Expand Down
Loading