From 51887dbca01d5632d513d0244f70c24c680fbcf9 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 00:31:06 +0800 Subject: [PATCH 1/4] feat(pt): add universal test for loss --- .../universal/common/cases/loss/__init__.py | 1 + .../tests/universal/common/cases/loss/loss.py | 11 +++ .../universal/common/cases/loss/utils.py | 73 +++++++++++++++ source/tests/universal/pt/backend.py | 44 ++++++++-- source/tests/universal/pt/loss/__init__.py | 1 + source/tests/universal/pt/loss/test_loss.py | 88 +++++++++++++++++++ 6 files changed, 213 insertions(+), 5 deletions(-) create mode 100644 source/tests/universal/common/cases/loss/__init__.py create mode 100644 source/tests/universal/common/cases/loss/loss.py create mode 100644 source/tests/universal/common/cases/loss/utils.py create mode 100644 source/tests/universal/pt/loss/__init__.py create mode 100644 source/tests/universal/pt/loss/test_loss.py diff --git a/source/tests/universal/common/cases/loss/__init__.py b/source/tests/universal/common/cases/loss/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/common/cases/loss/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/common/cases/loss/loss.py b/source/tests/universal/common/cases/loss/loss.py new file mode 100644 index 0000000000..a3b585114f --- /dev/null +++ b/source/tests/universal/common/cases/loss/loss.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +from .utils import ( + LossTestCase, +) + + +class LossTest(LossTestCase): + def setUp(self) -> None: + LossTestCase.setUp(self) diff --git a/source/tests/universal/common/cases/loss/utils.py b/source/tests/universal/common/cases/loss/utils.py new file mode 100644 index 0000000000..69d027103f --- /dev/null +++ b/source/tests/universal/common/cases/loss/utils.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np + +from deepmd.utils.data import ( + DataRequirementItem, +) + +from .....seed import ( + GLOBAL_SEED, +) + + +class LossTestCase: + """Common test case for loss function.""" + + def setUp(self): + pass + + def test_label_keys(self): + module = self.forward_wrapper(self.module) + label_requirement = self.module.label_requirement + label_dict = {item.key: item for item in label_requirement} + label_keys = sorted(label_dict.keys()) + label_keys_expected = sorted( + [key for key in self.key_to_pref_map if self.key_to_pref_map[key] > 0] + ) + np.testing.assert_equal(label_keys_expected, label_keys) + + def test_forward(self): + module = self.forward_wrapper(self.module) + label_requirement = self.module.label_requirement + label_dict = {item.key: item for item in label_requirement} + label_keys = sorted(label_dict.keys()) + natoms = 5 + + def fake_model(): + model_predict = { + data_key: fake_input_one_frame(label_dict[data_key], natoms) + for data_key in label_keys + } + if "atom_ener" in model_predict: + model_predict["atom_energy"] = model_predict.pop("atom_ener") + return model_predict + + labels = { + data_key: fake_input_one_frame(label_dict[data_key], natoms) + for data_key in label_keys + } + labels.update({"find_" + data_key: 1.0 for data_key in label_keys}) + + _, loss, more_loss = module( + {}, + fake_model, + labels, + natoms, + 1.0, + ) + + +def fake_input_one_frame(data_item: DataRequirementItem, natoms=5) -> np.ndarray: + ndof = data_item.ndof + atomic = data_item.atomic + repeat = data_item.repeat + rng = np.random.default_rng(seed=GLOBAL_SEED) + nframes = 1 + dtype = data_item.dtype if data_item.dtype is not None else np.float64 + if atomic: + ndof = ndof * natoms + data = rng.random([nframes, ndof], dtype) + if repeat != 1: + data = np.repeat(data, repeat).reshape([nframes, -1]) + return data diff --git a/source/tests/universal/pt/backend.py b/source/tests/universal/pt/backend.py index 5146fdc79b..326746b3c1 100644 --- a/source/tests/universal/pt/backend.py +++ b/source/tests/universal/pt/backend.py @@ -83,8 +83,8 @@ def forward_wrapper(self, module, on_cpu=False): def create_wrapper_method(method): def wrapper_method(self, *args, **kwargs): # convert to torch tensor - args = [to_torch_tensor(arg) for arg in args] - kwargs = {k: to_torch_tensor(v) for k, v in kwargs.items()} + args = [_to_torch_tensor(arg) for arg in args] + kwargs = {k: _to_torch_tensor(v) for k, v in kwargs.items()} if on_cpu: args = [ arg.detach().cpu() if arg is not None else None for arg in args @@ -97,11 +97,11 @@ def wrapper_method(self, *args, **kwargs): output = method(*args, **kwargs) # convert to numpy array if isinstance(output, tuple): - output = tuple(to_numpy_array(o) for o in output) + output = tuple(_to_numpy_array(o) for o in output) elif isinstance(output, dict): - output = {k: to_numpy_array(v) for k, v in output.items()} + output = {k: _to_numpy_array(v) for k, v in output.items()} else: - output = to_numpy_array(output) + output = _to_numpy_array(output) return output return wrapper_method @@ -112,3 +112,37 @@ class wrapper_module: forward_lower = create_wrapper_method(module.forward_lower) return wrapper_module() + + +def _to_torch_tensor(xx): + if isinstance(xx, dict): + return {kk: to_torch_tensor(xx[kk]) for kk in xx} + elif callable(xx): + return convert_to_torch_callable(xx) + else: + return to_torch_tensor(xx) + + +def convert_to_torch_callable(func): + def wrapper(*args, **kwargs): + output = _to_torch_tensor(func(*args, **kwargs)) + return output + + return wrapper + + +def _to_numpy_array(xx): + if isinstance(xx, dict): + return {kk: to_numpy_array(xx[kk]) for kk in xx} + elif callable(xx): + return convert_to_numpy_callable(xx) + else: + return to_numpy_array(xx) + + +def convert_to_numpy_callable(func): + def wrapper(*args, **kwargs): + output = _to_numpy_array(func(*args, **kwargs)) + return output + + return wrapper diff --git a/source/tests/universal/pt/loss/__init__.py b/source/tests/universal/pt/loss/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/pt/loss/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/pt/loss/test_loss.py b/source/tests/universal/pt/loss/test_loss.py new file mode 100644 index 0000000000..0ebdbac52f --- /dev/null +++ b/source/tests/universal/pt/loss/test_loss.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from collections import ( + OrderedDict, +) + +from deepmd.pt.loss import ( + EnergyStdLoss, +) + +from ....consistent.common import ( + parameterize_func, + parameterized, +) +from ...common.cases.loss.loss import ( + LossTest, +) + +# from ...dpmodel.fitting.test_fitting import ( +# FittingParamDipole, +# FittingParamDos, +# FittingParamEnergy, +# FittingParamPolar, +# FittingParamProperty, +# ) +from ..backend import ( + PTTestCase, +) + + +def LossParamEnergy( + starter_learning_rate=1.0, + pref_e=1.0, + pref_f=1.0, + pref_v=1.0, + pref_ae=1.0, +): + key_to_pref_map = { + "energy": pref_e, + "force": pref_f, + "virial": pref_v, + "atom_ener": pref_ae, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "starter_learning_rate": starter_learning_rate, + "start_pref_e": pref_e, + "limit_pref_e": pref_e / 2, + "start_pref_f": pref_f, + "limit_pref_f": pref_f / 2, + "start_pref_v": pref_v, + "limit_pref_v": pref_v / 2, + "start_pref_ae": pref_ae, + "limit_pref_ae": pref_ae / 2, + } + return input_dict + + +LossParamEnergyList = parameterize_func( + LossParamEnergy, + OrderedDict( + { + "pref_e": (1.0, 0.0), + "pref_f": (1.0, 0.0), + "pref_v": (1.0, 0.0), + "pref_ae": (1.0, 0.0), + } + ), +) +# to get name for the default function +LossParamEnergy = LossParamEnergyList[0] + + +@parameterized( + ( + *[(param_func, EnergyStdLoss) for param_func in LossParamEnergyList], + ) # class_param & class +) +class TestFittingPT(unittest.TestCase, LossTest, PTTestCase): + def setUp(self): + (LossParam, Loss) = self.param[0] + LossTest.setUp(self) + self.module_class = Loss + self.input_dict = LossParam() + self.key_to_pref_map = self.input_dict.pop("key_to_pref_map") + self.module = Loss(**self.input_dict) + self.skip_test_jit = True From bb6b06b12955d3bf2b126e11e54f30fe2dd8bbb2 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 22:17:19 +0800 Subject: [PATCH 2/4] add uts for loss --- deepmd/dpmodel/common.py | 4 +- deepmd/pt/loss/dos.py | 8 +- .../universal/common/cases/loss/utils.py | 6 +- .../tests/universal/dpmodel/loss/__init__.py | 1 + .../tests/universal/dpmodel/loss/test_loss.py | 203 ++++++++++++++++++ source/tests/universal/pt/loss/test_loss.py | 71 ++---- 6 files changed, 229 insertions(+), 64 deletions(-) create mode 100644 source/tests/universal/dpmodel/loss/__init__.py create mode 100644 source/tests/universal/dpmodel/loss/test_loss.py diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 6e6113b494..121e40978b 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -29,7 +29,7 @@ "double": np.float64, "int32": np.int32, "int64": np.int64, - "bool": bool, + "bool": np.bool_, "default": GLOBAL_NP_FLOAT_PRECISION, # NumPy doesn't have bfloat16 (and doesn't plan to add) # ml_dtypes is a solution, but it seems not supporting np.save/np.load @@ -45,7 +45,7 @@ np.int32: "int32", np.int64: "int64", ml_dtypes.bfloat16: "bfloat16", - bool: "bool", + np.bool_: "bool", } assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values()) DEFAULT_PRECISION = "float64" diff --git a/deepmd/pt/loss/dos.py b/deepmd/pt/loss/dos.py index 84513b6bf9..03765f18b9 100644 --- a/deepmd/pt/loss/dos.py +++ b/deepmd/pt/loss/dos.py @@ -151,10 +151,10 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False if self.has_acdf and "atom_dos" in model_pred and "atom_dos" in label: find_local = label.get("find_atom_dos", 0.0) pref_acdf = pref_acdf * find_local - local_tensor_pred_cdf = torch.cusum( + local_tensor_pred_cdf = torch.cumsum( model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1 ) - local_tensor_label_cdf = torch.cusum( + local_tensor_label_cdf = torch.cumsum( label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1 ) diff = (local_tensor_pred_cdf - local_tensor_label_cdf).reshape( @@ -199,10 +199,10 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False if self.has_cdf and "dos" in model_pred and "dos" in label: find_global = label.get("find_dos", 0.0) pref_cdf = pref_cdf * find_global - global_tensor_pred_cdf = torch.cusum( + global_tensor_pred_cdf = torch.cumsum( model_pred["dos"].reshape([-1, self.numb_dos]), dim=-1 ) - global_tensor_label_cdf = torch.cusum( + global_tensor_label_cdf = torch.cumsum( label["dos"].reshape([-1, self.numb_dos]), dim=-1 ) diff = global_tensor_pred_cdf - global_tensor_label_cdf diff --git a/source/tests/universal/common/cases/loss/utils.py b/source/tests/universal/common/cases/loss/utils.py index 69d027103f..6b6ffa662e 100644 --- a/source/tests/universal/common/cases/loss/utils.py +++ b/source/tests/universal/common/cases/loss/utils.py @@ -41,6 +41,7 @@ def fake_model(): } if "atom_ener" in model_predict: model_predict["atom_energy"] = model_predict.pop("atom_ener") + model_predict.update({"mask_mag": np.ones([1, natoms, 1], dtype=np.bool_)}) return model_predict labels = { @@ -66,8 +67,9 @@ def fake_input_one_frame(data_item: DataRequirementItem, natoms=5) -> np.ndarray nframes = 1 dtype = data_item.dtype if data_item.dtype is not None else np.float64 if atomic: - ndof = ndof * natoms - data = rng.random([nframes, ndof], dtype) + data = rng.random([nframes, natoms, ndof], dtype) + else: + data = rng.random([nframes, ndof], dtype) if repeat != 1: data = np.repeat(data, repeat).reshape([nframes, -1]) return data diff --git a/source/tests/universal/dpmodel/loss/__init__.py b/source/tests/universal/dpmodel/loss/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/dpmodel/loss/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/dpmodel/loss/test_loss.py b/source/tests/universal/dpmodel/loss/test_loss.py new file mode 100644 index 0000000000..6473c159da --- /dev/null +++ b/source/tests/universal/dpmodel/loss/test_loss.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections import ( + OrderedDict, +) + +from ....consistent.common import ( + parameterize_func, +) + + +def LossParamEnergy( + starter_learning_rate=1.0, + pref_e=1.0, + pref_f=1.0, + pref_v=1.0, + pref_ae=1.0, +): + key_to_pref_map = { + "energy": pref_e, + "force": pref_f, + "virial": pref_v, + "atom_ener": pref_ae, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "starter_learning_rate": starter_learning_rate, + "start_pref_e": pref_e, + "limit_pref_e": pref_e / 2, + "start_pref_f": pref_f, + "limit_pref_f": pref_f / 2, + "start_pref_v": pref_v, + "limit_pref_v": pref_v / 2, + "start_pref_ae": pref_ae, + "limit_pref_ae": pref_ae / 2, + } + return input_dict + + +LossParamEnergyList = parameterize_func( + LossParamEnergy, + OrderedDict( + { + "pref_e": (1.0, 0.0), + "pref_f": (1.0, 0.0), + "pref_v": (1.0, 0.0), + "pref_ae": (1.0, 0.0), + } + ), +) +# to get name for the default function +LossParamEnergy = LossParamEnergyList[0] + + +def LossParamEnergySpin( + starter_learning_rate=1.0, + pref_e=1.0, + pref_fr=1.0, + pref_fm=1.0, + pref_v=1.0, + pref_ae=1.0, +): + key_to_pref_map = { + "energy": pref_e, + "force": pref_fr, + "force_mag": pref_fm, + "virial": pref_v, + "atom_ener": pref_ae, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "starter_learning_rate": starter_learning_rate, + "start_pref_e": pref_e, + "limit_pref_e": pref_e / 2, + "start_pref_fr": pref_fr, + "limit_pref_fr": pref_fr / 2, + "start_pref_fm": pref_fm, + "limit_pref_fm": pref_fm / 2, + "start_pref_v": pref_v, + "limit_pref_v": pref_v / 2, + "start_pref_ae": pref_ae, + "limit_pref_ae": pref_ae / 2, + } + return input_dict + + +LossParamEnergySpinList = parameterize_func( + LossParamEnergySpin, + OrderedDict( + { + "pref_e": (1.0, 0.0), + "pref_fr": (1.0, 0.0), + "pref_fm": (1.0, 0.0), + "pref_v": (1.0, 0.0), + "pref_ae": (1.0, 0.0), + } + ), +) +# to get name for the default function +LossParamEnergySpin = LossParamEnergySpinList[0] + + +def LossParamDos( + starter_learning_rate=1.0, + pref_dos=1.0, + pref_ados=1.0, +): + key_to_pref_map = { + "dos": pref_dos, + "atom_dos": pref_ados, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "starter_learning_rate": starter_learning_rate, + "numb_dos": 2, + "start_pref_dos": pref_dos, + "limit_pref_dos": pref_dos / 2, + "start_pref_ados": pref_ados, + "limit_pref_ados": pref_ados / 2, + "start_pref_cdf": 0.0, + "limit_pref_cdf": 0.0, + "start_pref_acdf": 0.0, + "limit_pref_acdf": 0.0, + } + return input_dict + + +LossParamDosList = parameterize_func( + LossParamDos, + OrderedDict( + { + "pref_dos": (1.0,), + "pref_ados": (1.0, 0.0), + } + ), +) + parameterize_func( + LossParamDos, + OrderedDict( + { + "pref_dos": (0.0,), + "pref_ados": (1.0,), + } + ), +) + +# to get name for the default function +LossParamDos = LossParamDosList[0] + + +def LossParamTensor( + pref=1.0, + pref_atomic=1.0, +): + tensor_name = "test_tensor" + key_to_pref_map = { + tensor_name: pref, + f"atomic_{tensor_name}": pref_atomic, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "tensor_name": tensor_name, + "tensor_size": 2, + "label_name": tensor_name, + "pref": pref, + "pref_atomic": pref_atomic, + } + return input_dict + + +LossParamTensorList = parameterize_func( + LossParamTensor, + OrderedDict( + { + "pref": (1.0,), + "pref_atomic": (1.0, 0.0), + } + ), +) + parameterize_func( + LossParamTensor, + OrderedDict( + { + "pref": (0.0,), + "pref_atomic": (1.0,), + } + ), +) +# to get name for the default function +LossParamTensor = LossParamTensorList[0] + + +def LossParamProperty(): + key_to_pref_map = { + "property": 1.0, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "task_dim": 2, + } + return input_dict + + +LossParamPropertyList = [LossParamProperty] +# to get name for the default function +LossParamProperty = LossParamPropertyList[0] diff --git a/source/tests/universal/pt/loss/test_loss.py b/source/tests/universal/pt/loss/test_loss.py index 0ebdbac52f..928b12b669 100644 --- a/source/tests/universal/pt/loss/test_loss.py +++ b/source/tests/universal/pt/loss/test_loss.py @@ -1,80 +1,39 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -# SPDX-License-Identifier: LGPL-3.0-or-later import unittest -from collections import ( - OrderedDict, -) from deepmd.pt.loss import ( + DOSLoss, + EnergySpinLoss, EnergyStdLoss, + PropertyLoss, + TensorLoss, ) from ....consistent.common import ( - parameterize_func, parameterized, ) from ...common.cases.loss.loss import ( LossTest, ) - -# from ...dpmodel.fitting.test_fitting import ( -# FittingParamDipole, -# FittingParamDos, -# FittingParamEnergy, -# FittingParamPolar, -# FittingParamProperty, -# ) +from ...dpmodel.loss.test_loss import ( + LossParamDosList, + LossParamEnergyList, + LossParamEnergySpinList, + LossParamPropertyList, + LossParamTensorList, +) from ..backend import ( PTTestCase, ) -def LossParamEnergy( - starter_learning_rate=1.0, - pref_e=1.0, - pref_f=1.0, - pref_v=1.0, - pref_ae=1.0, -): - key_to_pref_map = { - "energy": pref_e, - "force": pref_f, - "virial": pref_v, - "atom_ener": pref_ae, - } - input_dict = { - "key_to_pref_map": key_to_pref_map, - "starter_learning_rate": starter_learning_rate, - "start_pref_e": pref_e, - "limit_pref_e": pref_e / 2, - "start_pref_f": pref_f, - "limit_pref_f": pref_f / 2, - "start_pref_v": pref_v, - "limit_pref_v": pref_v / 2, - "start_pref_ae": pref_ae, - "limit_pref_ae": pref_ae / 2, - } - return input_dict - - -LossParamEnergyList = parameterize_func( - LossParamEnergy, - OrderedDict( - { - "pref_e": (1.0, 0.0), - "pref_f": (1.0, 0.0), - "pref_v": (1.0, 0.0), - "pref_ae": (1.0, 0.0), - } - ), -) -# to get name for the default function -LossParamEnergy = LossParamEnergyList[0] - - @parameterized( ( *[(param_func, EnergyStdLoss) for param_func in LossParamEnergyList], + *[(param_func, EnergySpinLoss) for param_func in LossParamEnergySpinList], + *[(param_func, DOSLoss) for param_func in LossParamDosList], + *[(param_func, TensorLoss) for param_func in LossParamTensorList], + *[(param_func, PropertyLoss) for param_func in LossParamPropertyList], ) # class_param & class ) class TestFittingPT(unittest.TestCase, LossTest, PTTestCase): From de1f40aaf16f723d34ace708472d32be86f52d0e Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:08:37 +0800 Subject: [PATCH 3/4] Update source/tests/universal/pt/loss/test_loss.py Co-authored-by: Jinzhe Zeng Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> --- source/tests/universal/pt/loss/test_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/universal/pt/loss/test_loss.py b/source/tests/universal/pt/loss/test_loss.py index 928b12b669..47c2d06fbc 100644 --- a/source/tests/universal/pt/loss/test_loss.py +++ b/source/tests/universal/pt/loss/test_loss.py @@ -36,7 +36,7 @@ *[(param_func, PropertyLoss) for param_func in LossParamPropertyList], ) # class_param & class ) -class TestFittingPT(unittest.TestCase, LossTest, PTTestCase): +class TestLossPT(unittest.TestCase, LossTest, PTTestCase): def setUp(self): (LossParam, Loss) = self.param[0] LossTest.setUp(self) From f35e5422e9b36213f5e606f537bec3a4999842a9 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:12:33 +0800 Subject: [PATCH 4/4] Update utils.py --- source/tests/universal/common/cases/loss/utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/source/tests/universal/common/cases/loss/utils.py b/source/tests/universal/common/cases/loss/utils.py index 6b6ffa662e..63e6e3ed27 100644 --- a/source/tests/universal/common/cases/loss/utils.py +++ b/source/tests/universal/common/cases/loss/utils.py @@ -33,19 +33,24 @@ def test_forward(self): label_dict = {item.key: item for item in label_requirement} label_keys = sorted(label_dict.keys()) natoms = 5 + nframes = 2 def fake_model(): model_predict = { - data_key: fake_input_one_frame(label_dict[data_key], natoms) + data_key: fake_input( + label_dict[data_key], natoms=natoms, nframes=nframes + ) for data_key in label_keys } if "atom_ener" in model_predict: model_predict["atom_energy"] = model_predict.pop("atom_ener") - model_predict.update({"mask_mag": np.ones([1, natoms, 1], dtype=np.bool_)}) + model_predict.update( + {"mask_mag": np.ones([nframes, natoms, 1], dtype=np.bool_)} + ) return model_predict labels = { - data_key: fake_input_one_frame(label_dict[data_key], natoms) + data_key: fake_input(label_dict[data_key], natoms=natoms, nframes=nframes) for data_key in label_keys } labels.update({"find_" + data_key: 1.0 for data_key in label_keys}) @@ -59,12 +64,11 @@ def fake_model(): ) -def fake_input_one_frame(data_item: DataRequirementItem, natoms=5) -> np.ndarray: +def fake_input(data_item: DataRequirementItem, natoms=5, nframes=2) -> np.ndarray: ndof = data_item.ndof atomic = data_item.atomic repeat = data_item.repeat rng = np.random.default_rng(seed=GLOBAL_SEED) - nframes = 1 dtype = data_item.dtype if data_item.dtype is not None else np.float64 if atomic: data = rng.random([nframes, natoms, ndof], dtype)