Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax/array-api): dipole/polarizability fitting #4278

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
Expand Down Expand Up @@ -207,16 +208,19 @@ def call(
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`

"""
xp = array_api_compat.array_namespace(descriptor, atype)
nframes, nloc, _ = descriptor.shape
assert gr is not None, "Must provide the rotation matrix for dipole fitting."
# (nframes, nloc, m1)
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
# (nframes * nloc, 1, m1)
out = out.reshape(-1, 1, self.embedding_width)
out = xp.reshape(out, (-1, 1, self.embedding_width))
# (nframes * nloc, m1, 3)
gr = gr.reshape(nframes * nloc, -1, 3)
gr = xp.reshape(gr, (nframes * nloc, -1, 3))
# (nframes, nloc, 3)
out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3)
# out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3)
out = out @ gr
out = xp.reshape(out, (nframes, nloc, 3))
return {self.var_name: out}
71 changes: 41 additions & 30 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.common import (
Expand All @@ -14,6 +15,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
Expand Down Expand Up @@ -124,23 +128,18 @@

self.embedding_width = embedding_width
self.fit_diag = fit_diag
self.scale = scale
if self.scale is None:
self.scale = [1.0 for _ in range(ntypes)]
if scale is None:
scale = [1.0 for _ in range(ntypes)]
else:
if isinstance(self.scale, list):
assert (
len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
elif isinstance(self.scale, float):
self.scale = [self.scale for _ in range(ntypes)]
if isinstance(scale, list):
assert len(scale) == ntypes, "Scale should be a list of length ntypes."

Check warning on line 135 in deepmd/dpmodel/fitting/polarizability_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/polarizability_fitting.py#L135

Added line #L135 was not covered by tests
elif isinstance(scale, float):
scale = [scale for _ in range(ntypes)]
else:
raise ValueError(
"Scale must be a list of float of length ntypes or a float."
)
self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(
ntypes, 1
)
self.scale = np.array(scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(ntypes, 1)
self.shift_diag = shift_diag
self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION)
super().__init__(
Expand Down Expand Up @@ -192,8 +191,8 @@
data["embedding_width"] = self.embedding_width
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag
data["@variables"]["scale"] = self.scale
data["@variables"]["constant_matrix"] = self.constant_matrix
data["@variables"]["scale"] = to_numpy_array(self.scale)
data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix)
return data

@classmethod
Expand Down Expand Up @@ -276,6 +275,7 @@
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`

"""
xp = array_api_compat.array_namespace(descriptor, atype)
nframes, nloc, _ = descriptor.shape
assert (
gr is not None
Expand All @@ -284,28 +284,39 @@
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
out = out * self.scale[atype]
# out = out * self.scale[atype, ...]
scale_atype = xp.reshape(
xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1)
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
out = out * scale_atype
# (nframes * nloc, m1, 3)
gr = gr.reshape(nframes * nloc, -1, 3)
gr = xp.reshape(gr, (nframes * nloc, -1, 3))

if self.fit_diag:
out = out.reshape(-1, self.embedding_width)
out = np.einsum("ij,ijk->ijk", out, gr)
out = xp.reshape(out, (-1, self.embedding_width))
# out = np.einsum("ij,ijk->ijk", out, gr)
out = out[:, :, None] * gr
else:
out = out.reshape(-1, self.embedding_width, self.embedding_width)
out = (out + np.transpose(out, axes=(0, 2, 1))) / 2
out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3)
out = np.einsum(
"bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
) # (nframes * nloc, 3, 3)
out = out.reshape(nframes, nloc, 3, 3)
out = xp.reshape(out, (-1, self.embedding_width, self.embedding_width))
out = (out + xp.matrix_transpose(out)) / 2
# out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3)
out = out @ gr
# out = np.einsum(
# "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
# ) # (nframes * nloc, 3, 3)
out = xp.matrix_transpose(gr) @ out
out = xp.reshape(out, (nframes, nloc, 3, 3))
if self.shift_diag:
bias = self.constant_matrix[atype]
# bias = self.constant_matrix[atype]
bias = xp.reshape(
xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0),
(nframes, nloc),
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
# (nframes, nloc, 1)
bias = np.expand_dims(bias, axis=-1) * self.scale[atype]
eye = np.eye(3, dtype=descriptor.dtype)
eye = np.tile(eye, (nframes, nloc, 1, 1))
bias = bias[..., None] * scale_atype
eye = xp.eye(3, dtype=descriptor.dtype)
eye = xp.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
bias = np.expand_dims(bias, axis=-1) * eye
bias = bias[..., None] * eye
out = out + bias
return {"polarizability": out}
4 changes: 4 additions & 0 deletions deepmd/jax/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.fitting.fitting import (
DipoleFittingNet,
DOSFittingNet,
EnergyFittingNet,
PolarFittingNet,
)

__all__ = [
"EnergyFittingNet",
"DOSFittingNet",
"DipoleFittingNet",
"PolarFittingNet",
]
27 changes: 27 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
Any,
)

from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.dpmodel.fitting.polarizability_fitting import (
PolarFitting as PolarFittingNetDP,
)
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
Expand Down Expand Up @@ -53,3 +57,26 @@ class DOSFittingNet(DOSFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


@BaseFitting.register("dipole")
@flax_module
class DipoleFittingNet(DipoleFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


@BaseFitting.register("polar")
@flax_module
class PolarFittingNet(PolarFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
if name in {
"scale",
"constant_matrix",
}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
return super().__setattr__(name, value)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 2 additions & 2 deletions doc/model/train-fitting-tensor.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::
njzjz marked this conversation as resolved.
Show resolved Hide resolved

Unlike `energy`, which is a scalar, one may want to fit some high dimensional physical quantity, like `dipole` (vector) and `polarizability` (matrix, shorted as `polar`). Deep Potential has provided different APIs to do this. In this example, we will show you how to train a model to fit a water system. A complete training input script of the examples can be found in
Expand Down
21 changes: 21 additions & 0 deletions source/tests/array_api_strict/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
Any,
)

from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.dpmodel.fitting.polarizability_fitting import (
PolarFitting as PolarFittingNetDP,
)

from ..common import (
to_array_api_strict_array,
Expand Down Expand Up @@ -43,3 +47,20 @@ class DOSFittingNet(DOSFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


class DipoleFittingNet(DipoleFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


class PolarFittingNet(PolarFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
if name in {
"scale",
"constant_matrix",
}:
value = to_array_api_strict_array(value)
return super().__setattr__(name, value)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
41 changes: 41 additions & 0 deletions source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -32,6 +34,21 @@
from deepmd.tf.fit.dipole import DipoleFittingSeA as DipoleFittingTF
else:
DipoleFittingTF = object
if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import DipoleFittingNet as DipoleFittingJAX
else:
DipoleFittingJAX = object
if INSTALLED_ARRAY_API_STRICT:
import array_api_strict

from ...array_api_strict.fitting.fitting import (
DipoleFittingNet as DipoleFittingArrayAPIStrict,
)
else:
DipoleFittingArrayAPIStrict = object
from deepmd.utils.argcheck import (
fitting_dipole,
)
Expand Down Expand Up @@ -69,7 +86,11 @@ def skip_pt(self) -> bool:
tf_class = DipoleFittingTF
dp_class = DipoleFittingDP
pt_class = DipoleFittingPT
jax_class = DipoleFittingJAX
array_api_strict_class = DipoleFittingArrayAPIStrict
args = fitting_dipole()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)
Expand Down Expand Up @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any:
None,
)["dipole"]

def eval_jax(self, jax_obj: Any) -> Any:
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
jnp.asarray(self.gr),
None,
)["dipole"]
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
array_api_strict.asarray(self.gr),
None,
)["dipole"]
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
Expand Down
41 changes: 41 additions & 0 deletions source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -32,6 +34,21 @@
from deepmd.tf.fit.polar import PolarFittingSeA as PolarFittingTF
else:
PolarFittingTF = object
if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import PolarFittingNet as PolarFittingJAX
else:
PolarFittingJAX = object
if INSTALLED_ARRAY_API_STRICT:
import array_api_strict

from ...array_api_strict.fitting.fitting import (
PolarFittingNet as PolarFittingArrayAPIStrict,
)
else:
PolarFittingArrayAPIStrict = object
njzjz marked this conversation as resolved.
Show resolved Hide resolved
from deepmd.utils.argcheck import (
fitting_polar,
)
Expand Down Expand Up @@ -69,7 +86,11 @@ def skip_pt(self) -> bool:
tf_class = PolarFittingTF
dp_class = PolarFittingDP
pt_class = PolarFittingPT
jax_class = PolarFittingJAX
array_api_strict_class = PolarFittingArrayAPIStrict
args = fitting_polar()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)
Expand Down Expand Up @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any:
None,
)["polarizability"]

def eval_jax(self, jax_obj: Any) -> Any:
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
jnp.asarray(self.gr),
None,
)["polarizability"]
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
array_api_strict.asarray(self.gr),
None,
)["polarizability"]
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
Expand Down