From 2b4898f988293e2d18cbf15c0da248ae3c57dd33 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 30 Oct 2024 02:48:52 +0000 Subject: [PATCH] feat(jax/array-api): dipole/polarizability fitting Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/dipole_fitting.py | 10 ++- .../dpmodel/fitting/polarizability_fitting.py | 71 +++++++++++-------- deepmd/jax/fitting/__init__.py | 4 ++ deepmd/jax/fitting/fitting.py | 27 +++++++ doc/model/train-fitting-tensor.md | 4 +- .../tests/array_api_strict/fitting/fitting.py | 21 ++++++ .../tests/consistent/fitting/test_dipole.py | 41 +++++++++++ source/tests/consistent/fitting/test_polar.py | 41 +++++++++++ 8 files changed, 184 insertions(+), 35 deletions(-) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 01bd60c777..cecba865d0 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -207,6 +208,7 @@ def call( The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` """ + xp = array_api_compat.array_namespace(descriptor, atype) nframes, nloc, _ = descriptor.shape assert gr is not None, "Must provide the rotation matrix for dipole fitting." # (nframes, nloc, m1) @@ -214,9 +216,11 @@ def call( self.var_name ] # (nframes * nloc, 1, m1) - out = out.reshape(-1, 1, self.embedding_width) + out = xp.reshape(out, (-1, 1, self.embedding_width)) # (nframes * nloc, m1, 3) - gr = gr.reshape(nframes * nloc, -1, 3) + gr = xp.reshape(gr, (nframes * nloc, -1, 3)) # (nframes, nloc, 3) - out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3) + # out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3) + out = out @ gr + out = xp.reshape(out, (nframes, nloc, 3)) return {self.var_name: out} diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 2d96eec580..b972b45971 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.common import ( @@ -14,6 +15,9 @@ from deepmd.dpmodel import ( DEFAULT_PRECISION, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.base_fitting import ( BaseFitting, ) @@ -124,23 +128,18 @@ def __init__( self.embedding_width = embedding_width self.fit_diag = fit_diag - self.scale = scale - if self.scale is None: - self.scale = [1.0 for _ in range(ntypes)] + if scale is None: + scale = [1.0 for _ in range(ntypes)] else: - if isinstance(self.scale, list): - assert ( - len(self.scale) == ntypes - ), "Scale should be a list of length ntypes." - elif isinstance(self.scale, float): - self.scale = [self.scale for _ in range(ntypes)] + if isinstance(scale, list): + assert len(scale) == ntypes, "Scale should be a list of length ntypes." + elif isinstance(scale, float): + scale = [scale for _ in range(ntypes)] else: raise ValueError( "Scale must be a list of float of length ntypes or a float." ) - self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape( - ntypes, 1 - ) + self.scale = np.array(scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(ntypes, 1) self.shift_diag = shift_diag self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION) super().__init__( @@ -192,8 +191,8 @@ def serialize(self) -> dict: data["embedding_width"] = self.embedding_width data["fit_diag"] = self.fit_diag data["shift_diag"] = self.shift_diag - data["@variables"]["scale"] = self.scale - data["@variables"]["constant_matrix"] = self.constant_matrix + data["@variables"]["scale"] = to_numpy_array(self.scale) + data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix) return data @classmethod @@ -276,6 +275,7 @@ def call( The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` """ + xp = array_api_compat.array_namespace(descriptor, atype) nframes, nloc, _ = descriptor.shape assert ( gr is not None @@ -284,28 +284,39 @@ def call( out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - out = out * self.scale[atype] + # out = out * self.scale[atype, ...] + scale_atype = xp.reshape( + xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1) + ) + out = out * scale_atype # (nframes * nloc, m1, 3) - gr = gr.reshape(nframes * nloc, -1, 3) + gr = xp.reshape(gr, (nframes * nloc, -1, 3)) if self.fit_diag: - out = out.reshape(-1, self.embedding_width) - out = np.einsum("ij,ijk->ijk", out, gr) + out = xp.reshape(out, (-1, self.embedding_width)) + # out = np.einsum("ij,ijk->ijk", out, gr) + out = out[:, :, None] * gr else: - out = out.reshape(-1, self.embedding_width, self.embedding_width) - out = (out + np.transpose(out, axes=(0, 2, 1))) / 2 - out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) - out = np.einsum( - "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out - ) # (nframes * nloc, 3, 3) - out = out.reshape(nframes, nloc, 3, 3) + out = xp.reshape(out, (-1, self.embedding_width, self.embedding_width)) + out = (out + xp.matrix_transpose(out)) / 2 + # out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) + out = out @ gr + # out = np.einsum( + # "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out + # ) # (nframes * nloc, 3, 3) + out = xp.matrix_transpose(gr) @ out + out = xp.reshape(out, (nframes, nloc, 3, 3)) if self.shift_diag: - bias = self.constant_matrix[atype] + # bias = self.constant_matrix[atype] + bias = xp.reshape( + xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0), + (nframes, nloc), + ) # (nframes, nloc, 1) - bias = np.expand_dims(bias, axis=-1) * self.scale[atype] - eye = np.eye(3, dtype=descriptor.dtype) - eye = np.tile(eye, (nframes, nloc, 1, 1)) + bias = bias[..., None] * scale_atype + eye = xp.eye(3, dtype=descriptor.dtype) + eye = xp.tile(eye, (nframes, nloc, 1, 1)) # (nframes, nloc, 3, 3) - bias = np.expand_dims(bias, axis=-1) * eye + bias = bias[..., None] * eye out = out + bias return {"polarizability": out} diff --git a/deepmd/jax/fitting/__init__.py b/deepmd/jax/fitting/__init__.py index e72314dcab..226a6d5b43 100644 --- a/deepmd/jax/fitting/__init__.py +++ b/deepmd/jax/fitting/__init__.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.jax.fitting.fitting import ( + DipoleFittingNet, DOSFittingNet, EnergyFittingNet, + PolarFittingNet, ) __all__ = [ "EnergyFittingNet", "DOSFittingNet", + "DipoleFittingNet", + "PolarFittingNet", ] diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index cef1f667b3..2a6186ac46 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -3,8 +3,12 @@ Any, ) +from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting as PolarFittingNetDP, +) from deepmd.jax.common import ( ArrayAPIVariable, flax_module, @@ -53,3 +57,26 @@ class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) return super().__setattr__(name, value) + + +@BaseFitting.register("dipole") +@flax_module +class DipoleFittingNet(DipoleFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + +@BaseFitting.register("polar") +@flax_module +class PolarFittingNet(PolarFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + if name in { + "scale", + "constant_matrix", + }: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + return super().__setattr__(name, value) diff --git a/doc/model/train-fitting-tensor.md b/doc/model/train-fitting-tensor.md index c6b54c69ef..d4d546eccf 100644 --- a/doc/model/train-fitting-tensor.md +++ b/doc/model/train-fitting-tensor.md @@ -1,7 +1,7 @@ -# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: Unlike `energy`, which is a scalar, one may want to fit some high dimensional physical quantity, like `dipole` (vector) and `polarizability` (matrix, shorted as `polar`). Deep Potential has provided different APIs to do this. In this example, we will show you how to train a model to fit a water system. A complete training input script of the examples can be found in diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 8b65320203..5a2bd9c58f 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -3,8 +3,12 @@ Any, ) +from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting as PolarFittingNetDP, +) from ..common import ( to_array_api_strict_array, @@ -43,3 +47,20 @@ class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) return super().__setattr__(name, value) + + +class DipoleFittingNet(DipoleFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + +class PolarFittingNet(PolarFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + if name in { + "scale", + "constant_matrix", + }: + value = to_array_api_strict_array(value) + return super().__setattr__(name, value) diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 5d7be1b0e5..55d6c44c34 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -32,6 +34,21 @@ from deepmd.tf.fit.dipole import DipoleFittingSeA as DipoleFittingTF else: DipoleFittingTF = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import DipoleFittingNet as DipoleFittingJAX +else: + DipoleFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + DipoleFittingNet as DipoleFittingArrayAPIStrict, + ) +else: + DipoleFittingArrayAPIStrict = object from deepmd.utils.argcheck import ( fitting_dipole, ) @@ -69,7 +86,11 @@ def skip_pt(self) -> bool: tf_class = DipoleFittingTF dp_class = DipoleFittingDP pt_class = DipoleFittingPT + jax_class = DipoleFittingJAX + array_api_strict_class = DipoleFittingArrayAPIStrict args = fitting_dipole() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any: None, )["dipole"] + def eval_jax(self, jax_obj: Any) -> Any: + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + jnp.asarray(self.gr), + None, + )["dipole"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + array_api_strict.asarray(self.gr), + None, + )["dipole"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 6a3465ba24..895974baf9 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -32,6 +34,21 @@ from deepmd.tf.fit.polar import PolarFittingSeA as PolarFittingTF else: PolarFittingTF = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import PolarFittingNet as PolarFittingJAX +else: + PolarFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + PolarFittingNet as PolarFittingArrayAPIStrict, + ) +else: + PolarFittingArrayAPIStrict = object from deepmd.utils.argcheck import ( fitting_polar, ) @@ -69,7 +86,11 @@ def skip_pt(self) -> bool: tf_class = PolarFittingTF dp_class = PolarFittingDP pt_class = PolarFittingPT + jax_class = PolarFittingJAX + array_api_strict_class = PolarFittingArrayAPIStrict args = fitting_polar() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any: None, )["polarizability"] + def eval_jax(self, jax_obj: Any) -> Any: + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + jnp.asarray(self.gr), + None, + )["polarizability"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + array_api_strict.asarray(self.gr), + None, + )["polarizability"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same