feat(pt): add universal test for loss (#4354)
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced a new `LossTest` class for enhanced testing of loss functions.
  - Added multiple parameterized test functions for various loss functions in the new `test_loss.py` file.

- **Bug Fixes**
  - Corrected tensor operations in the `DOSLoss` class to ensure accurate cumulative sum calculations.

- **Documentation**
  - Added SPDX license identifiers to multiple files for clarity on licensing terms.

- **Chores**
  - Refactored data conversion methods in the `PTTestCase` class for improved handling of tensors and arrays.

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
iProzd and njzjz authored Nov 14, 2024
1 parent d3095cf commit 0ad4289
Showing 10 changed files with 388 additions and 11 deletions.
4 changes: 2 additions & 2 deletions deepmd/dpmodel/common.py

```diff
@@ -34,7 +34,7 @@
     "double": np.float64,
     "int32": np.int32,
     "int64": np.int64,
-    "bool": bool,
+    "bool": np.bool_,
     "default": GLOBAL_NP_FLOAT_PRECISION,
     # NumPy doesn't have bfloat16 (and doesn't plan to add)
     # ml_dtypes is a solution, but it seems not supporting np.save/np.load
@@ -50,7 +50,7 @@
     np.int32: "int32",
     np.int64: "int64",
     ml_dtypes.bfloat16: "bfloat16",
-    bool: "bool",
+    np.bool_: "bool",
 }
 assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values())
 DEFAULT_PRECISION = "float64"
```
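
For context, a brief illustration (a minimal sketch, not part of this diff) of why the mapping switches from the built-in `bool` to `np.bool_`: boolean NumPy arrays report `np.bool_` as their scalar type, so a reverse lookup in `RESERVED_PRECISON_DICT` keyed on the built-in `bool` would not match an array's `dtype.type`, while `np.bool_` does.

```python
import numpy as np

# Minimal sketch: the scalar type of a boolean NumPy array is np.bool_,
# not the Python built-in bool.
arr = np.zeros(3, dtype=bool)
print(arr.dtype.type is np.bool_)  # True
print(arr.dtype.type is bool)      # False

# A reverse map keyed on np.bool_ therefore resolves correctly:
reserved = {np.float64: "double", np.bool_: "bool"}
print(reserved[arr.dtype.type])    # "bool"
```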
8 changes: 4 additions & 4 deletions deepmd/pt/loss/dos.py

```diff
@@ -151,10 +151,10 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
         if self.has_acdf and "atom_dos" in model_pred and "atom_dos" in label:
             find_local = label.get("find_atom_dos", 0.0)
             pref_acdf = pref_acdf * find_local
-            local_tensor_pred_cdf = torch.cusum(
+            local_tensor_pred_cdf = torch.cumsum(
                 model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
             )
-            local_tensor_label_cdf = torch.cusum(
+            local_tensor_label_cdf = torch.cumsum(
                 label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
             )
             diff = (local_tensor_pred_cdf - local_tensor_label_cdf).reshape(
@@ -199,10 +199,10 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
         if self.has_cdf and "dos" in model_pred and "dos" in label:
             find_global = label.get("find_dos", 0.0)
             pref_cdf = pref_cdf * find_global
-            global_tensor_pred_cdf = torch.cusum(
+            global_tensor_pred_cdf = torch.cumsum(
                 model_pred["dos"].reshape([-1, self.numb_dos]), dim=-1
             )
-            global_tensor_label_cdf = torch.cusum(
+            global_tensor_label_cdf = torch.cumsum(
                 label["dos"].reshape([-1, self.numb_dos]), dim=-1
             )
             diff = global_tensor_pred_cdf - global_tensor_label_cdf
```
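
For reference, a minimal sketch (illustrative shapes only, not taken from this commit) of the corrected call: `torch.cumsum` accumulates along the last axis, turning per-bin DOS values into a cumulative distribution, which is what the CDF terms of `DOSLoss` compare. The misspelled `torch.cusum` does not exist, so the old branches failed when reached.

```python
import torch

# Illustrative shapes: 2 frames, 5 atoms, 4 DOS bins.
atom_dos = torch.rand(2, 5, 4)

# Cumulative sum over the DOS bins (last dimension).
atom_cdf = torch.cumsum(atom_dos, dim=-1)

# The last entry of the cumulative sum equals the total over all bins.
assert torch.allclose(atom_cdf[..., -1], atom_dos.sum(dim=-1))
```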
1 change: 1 addition & 0 deletions source/tests/universal/common/cases/loss/__init__.py
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
11 changes: 11 additions & 0 deletions source/tests/universal/common/cases/loss/loss.py
@@ -0,0 +1,11 @@
```python
# SPDX-License-Identifier: LGPL-3.0-or-later


from .utils import (
    LossTestCase,
)


class LossTest(LossTestCase):
    def setUp(self) -> None:
        LossTestCase.setUp(self)
```
79 changes: 79 additions & 0 deletions source/tests/universal/common/cases/loss/utils.py
@@ -0,0 +1,79 @@
```python
# SPDX-License-Identifier: LGPL-3.0-or-later

import numpy as np

from deepmd.utils.data import (
    DataRequirementItem,
)

from .....seed import (
    GLOBAL_SEED,
)


class LossTestCase:
    """Common test case for loss function."""

    def setUp(self):
        pass

    def test_label_keys(self):
        module = self.forward_wrapper(self.module)
        label_requirement = self.module.label_requirement
        label_dict = {item.key: item for item in label_requirement}
        label_keys = sorted(label_dict.keys())
        label_keys_expected = sorted(
            [key for key in self.key_to_pref_map if self.key_to_pref_map[key] > 0]
        )
        np.testing.assert_equal(label_keys_expected, label_keys)

    def test_forward(self):
        module = self.forward_wrapper(self.module)
        label_requirement = self.module.label_requirement
        label_dict = {item.key: item for item in label_requirement}
        label_keys = sorted(label_dict.keys())
        natoms = 5
        nframes = 2

        def fake_model():
            model_predict = {
                data_key: fake_input(
                    label_dict[data_key], natoms=natoms, nframes=nframes
                )
                for data_key in label_keys
            }
            if "atom_ener" in model_predict:
                model_predict["atom_energy"] = model_predict.pop("atom_ener")
            model_predict.update(
                {"mask_mag": np.ones([nframes, natoms, 1], dtype=np.bool_)}
            )
            return model_predict

        labels = {
            data_key: fake_input(label_dict[data_key], natoms=natoms, nframes=nframes)
            for data_key in label_keys
        }
        labels.update({"find_" + data_key: 1.0 for data_key in label_keys})

        _, loss, more_loss = module(
            {},
            fake_model,
            labels,
            natoms,
            1.0,
        )


def fake_input(data_item: DataRequirementItem, natoms=5, nframes=2) -> np.ndarray:
    ndof = data_item.ndof
    atomic = data_item.atomic
    repeat = data_item.repeat
    rng = np.random.default_rng(seed=GLOBAL_SEED)
    dtype = data_item.dtype if data_item.dtype is not None else np.float64
    if atomic:
        data = rng.random([nframes, natoms, ndof], dtype)
    else:
        data = rng.random([nframes, ndof], dtype)
    if repeat != 1:
        data = np.repeat(data, repeat).reshape([nframes, -1])
    return data
```
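
`LossTestCase` assumes the concrete test class provides `self.module` (a loss instance), `self.key_to_pref_map`, and `self.forward_wrapper`. The backend-side wiring lives in changed files not shown above; the following is only a hypothetical sketch of how a PyTorch test might supply those attributes (the import paths, the `PTTestCase` mixin, and the `EnergyStdLoss` arguments are assumptions, not code from this commit).

```python
# Hypothetical sketch; the real backend test is in this PR's PT test file.
import unittest

from deepmd.pt.loss import EnergyStdLoss  # assumed import path

from ...common.cases.loss.loss import LossTest  # assumed relative path
from ..backend import PTTestCase  # assumed helper supplying forward_wrapper


class TestEnergyLossPT(unittest.TestCase, LossTest, PTTestCase):
    def setUp(self):
        LossTest.setUp(self)
        # Mirrors the defaults of LossParamEnergy() in test_loss.py (assumption).
        self.key_to_pref_map = {
            "energy": 1.0,
            "force": 1.0,
            "virial": 1.0,
            "atom_ener": 1.0,
        }
        self.module = EnergyStdLoss(
            starter_learning_rate=1.0,
            start_pref_e=1.0,
            limit_pref_e=0.5,
            start_pref_f=1.0,
            limit_pref_f=0.5,
        )
```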
1 change: 1 addition & 0 deletions source/tests/universal/dpmodel/loss/__init__.py
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
203 changes: 203 additions & 0 deletions source/tests/universal/dpmodel/loss/test_loss.py
@@ -0,0 +1,203 @@
```python
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections import (
    OrderedDict,
)

from ....consistent.common import (
    parameterize_func,
)


def LossParamEnergy(
    starter_learning_rate=1.0,
    pref_e=1.0,
    pref_f=1.0,
    pref_v=1.0,
    pref_ae=1.0,
):
    key_to_pref_map = {
        "energy": pref_e,
        "force": pref_f,
        "virial": pref_v,
        "atom_ener": pref_ae,
    }
    input_dict = {
        "key_to_pref_map": key_to_pref_map,
        "starter_learning_rate": starter_learning_rate,
        "start_pref_e": pref_e,
        "limit_pref_e": pref_e / 2,
        "start_pref_f": pref_f,
        "limit_pref_f": pref_f / 2,
        "start_pref_v": pref_v,
        "limit_pref_v": pref_v / 2,
        "start_pref_ae": pref_ae,
        "limit_pref_ae": pref_ae / 2,
    }
    return input_dict


LossParamEnergyList = parameterize_func(
    LossParamEnergy,
    OrderedDict(
        {
            "pref_e": (1.0, 0.0),
            "pref_f": (1.0, 0.0),
            "pref_v": (1.0, 0.0),
            "pref_ae": (1.0, 0.0),
        }
    ),
)
# to get name for the default function
LossParamEnergy = LossParamEnergyList[0]


def LossParamEnergySpin(
    starter_learning_rate=1.0,
    pref_e=1.0,
    pref_fr=1.0,
    pref_fm=1.0,
    pref_v=1.0,
    pref_ae=1.0,
):
    key_to_pref_map = {
        "energy": pref_e,
        "force": pref_fr,
        "force_mag": pref_fm,
        "virial": pref_v,
        "atom_ener": pref_ae,
    }
    input_dict = {
        "key_to_pref_map": key_to_pref_map,
        "starter_learning_rate": starter_learning_rate,
        "start_pref_e": pref_e,
        "limit_pref_e": pref_e / 2,
        "start_pref_fr": pref_fr,
        "limit_pref_fr": pref_fr / 2,
        "start_pref_fm": pref_fm,
        "limit_pref_fm": pref_fm / 2,
        "start_pref_v": pref_v,
        "limit_pref_v": pref_v / 2,
        "start_pref_ae": pref_ae,
        "limit_pref_ae": pref_ae / 2,
    }
    return input_dict


LossParamEnergySpinList = parameterize_func(
    LossParamEnergySpin,
    OrderedDict(
        {
            "pref_e": (1.0, 0.0),
            "pref_fr": (1.0, 0.0),
            "pref_fm": (1.0, 0.0),
            "pref_v": (1.0, 0.0),
            "pref_ae": (1.0, 0.0),
        }
    ),
)
# to get name for the default function
LossParamEnergySpin = LossParamEnergySpinList[0]


def LossParamDos(
    starter_learning_rate=1.0,
    pref_dos=1.0,
    pref_ados=1.0,
):
    key_to_pref_map = {
        "dos": pref_dos,
        "atom_dos": pref_ados,
    }
    input_dict = {
        "key_to_pref_map": key_to_pref_map,
        "starter_learning_rate": starter_learning_rate,
        "numb_dos": 2,
        "start_pref_dos": pref_dos,
        "limit_pref_dos": pref_dos / 2,
        "start_pref_ados": pref_ados,
        "limit_pref_ados": pref_ados / 2,
        "start_pref_cdf": 0.0,
        "limit_pref_cdf": 0.0,
        "start_pref_acdf": 0.0,
        "limit_pref_acdf": 0.0,
    }
    return input_dict


LossParamDosList = parameterize_func(
    LossParamDos,
    OrderedDict(
        {
            "pref_dos": (1.0,),
            "pref_ados": (1.0, 0.0),
        }
    ),
) + parameterize_func(
    LossParamDos,
    OrderedDict(
        {
            "pref_dos": (0.0,),
            "pref_ados": (1.0,),
        }
    ),
)

# to get name for the default function
LossParamDos = LossParamDosList[0]


def LossParamTensor(
    pref=1.0,
    pref_atomic=1.0,
):
    tensor_name = "test_tensor"
    key_to_pref_map = {
        tensor_name: pref,
        f"atomic_{tensor_name}": pref_atomic,
    }
    input_dict = {
        "key_to_pref_map": key_to_pref_map,
        "tensor_name": tensor_name,
        "tensor_size": 2,
        "label_name": tensor_name,
        "pref": pref,
        "pref_atomic": pref_atomic,
    }
    return input_dict


LossParamTensorList = parameterize_func(
    LossParamTensor,
    OrderedDict(
        {
            "pref": (1.0,),
            "pref_atomic": (1.0, 0.0),
        }
    ),
) + parameterize_func(
    LossParamTensor,
    OrderedDict(
        {
            "pref": (0.0,),
            "pref_atomic": (1.0,),
        }
    ),
)
# to get name for the default function
LossParamTensor = LossParamTensorList[0]


def LossParamProperty():
    key_to_pref_map = {
        "property": 1.0,
    }
    input_dict = {
        "key_to_pref_map": key_to_pref_map,
        "task_dim": 2,
    }
    return input_dict


LossParamPropertyList = [LossParamProperty]
# to get name for the default function
LossParamProperty = LossParamPropertyList[0]
```
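
Each `LossParam*` factory returns the constructor keyword arguments for one loss configuration together with a `key_to_pref_map` describing which labels should be required, and `parameterize_func` (from the consistent-test helpers) expands the `OrderedDict` of candidate values into one such factory per combination. A rough sketch of how a backend test might consume one of these lists follows; the import paths and the use of `EnergyStdLoss` are assumptions, not code from this commit.

```python
# Rough sketch, assuming a consumer instantiates one loss per parameter set.
from deepmd.pt.loss import EnergyStdLoss  # assumed consumer-side import

from .test_loss import LossParamEnergyList  # factories produced by parameterize_func


def build_energy_losses():
    losses = []
    for param_func in LossParamEnergyList:
        input_dict = dict(param_func())  # e.g. {"start_pref_e": 1.0, "limit_pref_e": 0.5, ...}
        key_to_pref_map = input_dict.pop("key_to_pref_map")
        losses.append((key_to_pref_map, EnergyStdLoss(**input_dict)))
    return losses
```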