From 5050f611133665580fb44cd62cbe6d84d4864ac8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 15 Oct 2024 23:40:01 -0400 Subject: [PATCH] feat(jax/array-api): DOS fitting (#4218) ## Summary by CodeRabbit - **New Features** - Introduced the `DOSFittingNet` class for enhanced fitting capabilities. - Added methods to evaluate different backends (JAX and Array API Strict) for computing density of states. - Enhanced testing framework to conditionally include tests based on library availability. - **Bug Fixes** - Improved serialization of the `bias_atom_e` variable to ensure consistent data representation. - **Tests** - Expanded the `TestDOS` class with new attributes and methods for better backend evaluation. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- deepmd/dpmodel/fitting/dos_fitting.py | 3 +- deepmd/jax/fitting/fitting.py | 8 +++ .../tests/array_api_strict/fitting/fitting.py | 7 +++ source/tests/consistent/fitting/test_dos.py | 59 +++++++++++++++++++ 4 files changed, 76 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/fitting/dos_fitting.py b/deepmd/dpmodel/fitting/dos_fitting.py index e9cd4a17ae..32225ac6c0 100644 --- a/deepmd/dpmodel/fitting/dos_fitting.py +++ b/deepmd/dpmodel/fitting/dos_fitting.py @@ -10,6 +10,7 @@ from deepmd.dpmodel.common import ( DEFAULT_PRECISION, + to_numpy_array, ) from deepmd.dpmodel.fitting.invar_fitting import ( InvarFitting, @@ -89,6 +90,6 @@ def serialize(self) -> dict: **super().serialize(), "type": "dos", } - dd["@variables"]["bias_atom_e"] = self.bias_atom_e + dd["@variables"]["bias_atom_e"] = to_numpy_array(self.bias_atom_e) return dd diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index 27ad791db9..284213c70a 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -3,6 +3,7 @@ Any, ) +from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.jax.common import ( flax_module, @@ -37,3 +38,10 @@ class EnergyFittingNet(EnergyFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) return super().__setattr__(name, value) + + +@flax_module +class DOSFittingNet(DOSFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 2e6bd9fe25..8b65320203 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -3,6 +3,7 @@ Any, ) +from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from ..common import ( @@ -36,3 +37,9 @@ class EnergyFittingNet(EnergyFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) return super().__setattr__(name, value) + + +class DOSFittingNet(DOSFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index ada65c8ac5..4a78b69341 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -36,6 +38,20 @@ fitting_dos, ) +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import DOSFittingNet as DOSFittingJAX +else: + DOSFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import DOSFittingNet as DOSFittingStrict +else: + DOSFittingStrict = object + @parameterized( (True, False), # resnet_dt @@ -74,9 +90,19 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_jax(self) -> bool: + return not INSTALLED_JAX + + @property + def skip_array_api_strict(self) -> bool: + return not INSTALLED_ARRAY_API_STRICT + tf_class = DOSFittingTF dp_class = DOSFittingDP pt_class = DOSFittingPT + jax_class = DOSFittingJAX + array_api_strict_class = DOSFittingStrict args = fitting_dos() def setUp(self): @@ -157,6 +183,39 @@ def eval_dp(self, dp_obj: Any) -> Any: fparam=self.fparam if numb_fparam else None, )["dos"] + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if numb_fparam else None, + )["dos"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None, + )["dos"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same