From 29f1b7773fbe5e1f46d05192dfb32bb093137b96 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 30 Oct 2024 17:58:15 -0400 Subject: [PATCH 1/2] feat(jax/array-api): property fitting Signed-off-by: Jinzhe Zeng --- deepmd/jax/fitting/fitting.py | 11 ++++ .../tests/array_api_strict/fitting/fitting.py | 9 +++ .../tests/consistent/fitting/test_property.py | 62 +++++++++++++++++++ 3 files changed, 82 insertions(+) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index cef1f667b3..963c40ed2a 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -5,6 +5,9 @@ from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from deepmd.jax.common import ( ArrayAPIVariable, flax_module, @@ -47,6 +50,14 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseFitting.register("property") +@flax_module +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + @BaseFitting.register("dos") @flax_module class DOSFittingNet(DOSFittingNetDP): diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 8b65320203..bcc4171cb2 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -5,6 +5,9 @@ from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from ..common import ( to_array_api_strict_array, @@ -39,6 +42,12 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index beb21d9c04..4e0fe04f9f 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -17,6 +17,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -32,6 +34,22 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: PropertyFittingPT = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import PropertyFittingNet as PropertyFittingJAX +else: + PropertyFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + PropertyFittingNet as PropertyFittingStrict, + ) +else: + PropertyFittingStrict = object + PropertyFittingTF = object @@ -84,9 +102,14 @@ def skip_pt(self) -> bool: def skip_tf(self) -> bool: return True + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + tf_class = PropertyFittingTF dp_class = PropertyFittingDP pt_class = PropertyFittingPT + jax_class = PropertyFittingJAX + array_api_strict_class = PropertyFittingStrict args = fitting_property() def setUp(self): @@ -183,6 +206,45 @@ def eval_dp(self, dp_obj: Any) -> Any: aparam=self.aparam if numb_aparam else None, )["property"] + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if numb_fparam else None, + aparam=jnp.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None, + aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same From 94d310509742d5ed3b480e67489314427ff5daaf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 01:14:26 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/jax/fitting/fitting.py | 6 +++--- source/tests/array_api_strict/fitting/fitting.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index 8a9c4cf321..d62681490c 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -6,12 +6,12 @@ from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP -from deepmd.dpmodel.fitting.property_fitting import ( - PropertyFittingNet as PropertyFittingNetDP, -) from deepmd.dpmodel.fitting.polarizability_fitting import ( PolarFitting as PolarFittingNetDP, ) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from deepmd.jax.common import ( ArrayAPIVariable, flax_module, diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 8103e4c120..323a49cfe8 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -6,12 +6,12 @@ from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP -from deepmd.dpmodel.fitting.property_fitting import ( - PropertyFittingNet as PropertyFittingNetDP, -) from deepmd.dpmodel.fitting.polarizability_fitting import ( PolarFitting as PolarFittingNetDP, ) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from ..common import ( to_array_api_strict_array,