Skip to content

Commit

Permalink
Merge branch 'devel' into feat/zbl-consistency-test
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml authored Nov 1, 2024
2 parents 6bb4749 + 704db2f commit 0e4208d
Show file tree
Hide file tree
Showing 35 changed files with 476 additions and 74 deletions.
7 changes: 7 additions & 0 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.utils.entry_point import (
load_entry_point,
)

from .common import (
DEFAULT_PRECISION,
PRECISION_DICT,
Expand Down Expand Up @@ -32,3 +36,6 @@
"get_deriv_name",
"get_hessian_name",
]


load_entry_point("deepmd.dpmodel")
13 changes: 8 additions & 5 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(
), f"number of atom types in {ii}th descriptor {self.descrpt_list[0].__class__.__name__} does not match others"
# if hybrid sel is larger than sub sel, the nlist needs to be cut for each type
hybrid_sel = self.get_sel()
self.nlist_cut_idx: list[np.ndarray] = []
nlist_cut_idx: list[np.ndarray] = []
if self.mixed_types() and not all(
descrpt.mixed_types() for descrpt in self.descrpt_list
):
Expand All @@ -92,7 +93,8 @@ def __init__(
cut_idx = np.concatenate(
[range(ss, ee) for ss, ee in zip(start_idx, end_idx)]
)
self.nlist_cut_idx.append(cut_idx)
nlist_cut_idx.append(cut_idx)
self.nlist_cut_idx = nlist_cut_idx

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -242,6 +244,7 @@ def call(
sw
The smooth switch function.
"""
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
out_descriptor = []
out_gr = []
out_g2 = None
Expand All @@ -258,7 +261,7 @@ def call(
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx):
# cut the nlist to the correct length
if self.mixed_types() == descrpt.mixed_types():
nl = nlist[:, :, nci]
nl = xp.take(nlist, nci, axis=2)
else:
# mixed_types is True, but descrpt.mixed_types is False
assert nl_distinguish_types is not None
Expand All @@ -268,8 +271,8 @@ def call(
if gr is not None:
out_gr.append(gr)

out_descriptor = np.concatenate(out_descriptor, axis=-1)
out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None
out_descriptor = xp.concat(out_descriptor, axis=-1)
out_gr = xp.concat(out_gr, axis=-2) if out_gr else None
return out_descriptor, out_gr, out_g2, out_h2, out_sw

@classmethod
Expand Down
10 changes: 7 additions & 3 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
Expand Down Expand Up @@ -207,16 +208,19 @@ def call(
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`
"""
xp = array_api_compat.array_namespace(descriptor, atype)
nframes, nloc, _ = descriptor.shape
assert gr is not None, "Must provide the rotation matrix for dipole fitting."
# (nframes, nloc, m1)
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
# (nframes * nloc, 1, m1)
out = out.reshape(-1, 1, self.embedding_width)
out = xp.reshape(out, (-1, 1, self.embedding_width))
# (nframes * nloc, m1, 3)
gr = gr.reshape(nframes * nloc, -1, 3)
gr = xp.reshape(gr, (nframes * nloc, -1, 3))
# (nframes, nloc, 3)
out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3)
# out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3)
out = out @ gr
out = xp.reshape(out, (nframes, nloc, 3))
return {self.var_name: out}
8 changes: 4 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def _call_common(
assert fparam is not None, "fparam should not be None"
if fparam.shape[-1] != self.numb_fparam:
raise ValueError(
"get an input fparam of dim {fparam.shape[-1]}, ",
"which is not consistent with {self.numb_fparam}.",
f"get an input fparam of dim {fparam.shape[-1]}, "
f"which is not consistent with {self.numb_fparam}."
)
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
fparam = xp.tile(
Expand All @@ -409,8 +409,8 @@ def _call_common(
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
raise ValueError(
"get an input aparam of dim {aparam.shape[-1]}, ",
"which is not consistent with {self.numb_aparam}.",
f"get an input aparam of dim {aparam.shape[-1]}, "
f"which is not consistent with {self.numb_aparam}."
)
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std
Expand Down
71 changes: 41 additions & 30 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.common import (
Expand All @@ -14,6 +15,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
Expand Down Expand Up @@ -124,23 +128,18 @@ def __init__(

self.embedding_width = embedding_width
self.fit_diag = fit_diag
self.scale = scale
if self.scale is None:
self.scale = [1.0 for _ in range(ntypes)]
if scale is None:
scale = [1.0 for _ in range(ntypes)]
else:
if isinstance(self.scale, list):
assert (
len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
elif isinstance(self.scale, float):
self.scale = [self.scale for _ in range(ntypes)]
if isinstance(scale, list):
assert len(scale) == ntypes, "Scale should be a list of length ntypes."
elif isinstance(scale, float):
scale = [scale for _ in range(ntypes)]
else:
raise ValueError(
"Scale must be a list of float of length ntypes or a float."
)
self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(
ntypes, 1
)
self.scale = np.array(scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(ntypes, 1)
self.shift_diag = shift_diag
self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION)
super().__init__(
Expand Down Expand Up @@ -192,8 +191,8 @@ def serialize(self) -> dict:
data["embedding_width"] = self.embedding_width
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag
data["@variables"]["scale"] = self.scale
data["@variables"]["constant_matrix"] = self.constant_matrix
data["@variables"]["scale"] = to_numpy_array(self.scale)
data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix)
return data

@classmethod
Expand Down Expand Up @@ -276,6 +275,7 @@ def call(
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`
"""
xp = array_api_compat.array_namespace(descriptor, atype)
nframes, nloc, _ = descriptor.shape
assert (
gr is not None
Expand All @@ -284,28 +284,39 @@ def call(
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
out = out * self.scale[atype]
# out = out * self.scale[atype, ...]
scale_atype = xp.reshape(
xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1)
)
out = out * scale_atype
# (nframes * nloc, m1, 3)
gr = gr.reshape(nframes * nloc, -1, 3)
gr = xp.reshape(gr, (nframes * nloc, -1, 3))

if self.fit_diag:
out = out.reshape(-1, self.embedding_width)
out = np.einsum("ij,ijk->ijk", out, gr)
out = xp.reshape(out, (-1, self.embedding_width))
# out = np.einsum("ij,ijk->ijk", out, gr)
out = out[:, :, None] * gr
else:
out = out.reshape(-1, self.embedding_width, self.embedding_width)
out = (out + np.transpose(out, axes=(0, 2, 1))) / 2
out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3)
out = np.einsum(
"bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
) # (nframes * nloc, 3, 3)
out = out.reshape(nframes, nloc, 3, 3)
out = xp.reshape(out, (-1, self.embedding_width, self.embedding_width))
out = (out + xp.matrix_transpose(out)) / 2
# out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3)
out = out @ gr
# out = np.einsum(
# "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
# ) # (nframes * nloc, 3, 3)
out = xp.matrix_transpose(gr) @ out
out = xp.reshape(out, (nframes, nloc, 3, 3))
if self.shift_diag:
bias = self.constant_matrix[atype]
# bias = self.constant_matrix[atype]
bias = xp.reshape(
xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0),
(nframes, nloc),
)
# (nframes, nloc, 1)
bias = np.expand_dims(bias, axis=-1) * self.scale[atype]
eye = np.eye(3, dtype=descriptor.dtype)
eye = np.tile(eye, (nframes, nloc, 1, 1))
bias = bias[..., None] * scale_atype
eye = xp.eye(3, dtype=descriptor.dtype)
eye = xp.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
bias = np.expand_dims(bias, axis=-1) * eye
bias = bias[..., None] * eye
out = out + bias
return {"polarizability": out}
21 changes: 17 additions & 4 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ def eval(
The output of the evaluation. The keys are the names of the output
variables, and the values are the corresponding output arrays.
"""
if fparam is not None or aparam is not None:
raise NotImplementedError
# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
coords = np.array(coords)
Expand All @@ -216,7 +214,7 @@ def eval(
)
request_defs = self._get_request_defs(atomic)
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, request_defs
coords, cells, atom_types, fparam, aparam, request_defs
)
return dict(
zip(
Expand Down Expand Up @@ -306,6 +304,8 @@ def _eval_model(
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
request_defs: list[OutputVariableDef],
):
model = self.dp
Expand All @@ -323,12 +323,25 @@ def _eval_model(
box_input = cells.reshape([-1, 3, 3])
else:
box_input = None
if fparam is not None:
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
else:
fparam_input = None
if aparam is not None:
aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
else:
aparam_input = None

do_atomic_virial = any(
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
)
batch_output = model(
coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial
coord_input,
type_input,
box=box_input,
fparam=fparam_input,
aparam=aparam_input,
do_atomic_virial=do_atomic_virial,
)
if isinstance(batch_output, tuple):
batch_output = batch_output[0]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test(
log.info(f"# testing system : {system}")

# create data class
tmap = dp.get_type_map() if isinstance(dp, DeepPot) else None
tmap = dp.get_type_map()
data = DeepmdData(
system,
set_prefix="set",
Expand Down
6 changes: 6 additions & 0 deletions deepmd/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""JAX backend."""

from deepmd.utils.entry_point import (
load_entry_point,
)

load_entry_point("deepmd.jax")
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from deepmd.jax.descriptor.dpa1 import (
DescrptDPA1,
)
from deepmd.jax.descriptor.hybrid import (
DescrptHybrid,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -13,4 +16,5 @@
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptHybrid",
]
26 changes: 26 additions & 0 deletions deepmd/jax/descriptor/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("hybrid")
@flax_module
class DescrptHybrid(DescrptHybridDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"nlist_cut_idx"}:
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
elif name in {"descrpt_list"}:
value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value]

return super().__setattr__(name, value)
4 changes: 4 additions & 0 deletions deepmd/jax/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.fitting.fitting import (
DipoleFittingNet,
DOSFittingNet,
EnergyFittingNet,
PolarFittingNet,
)

__all__ = [
"EnergyFittingNet",
"DOSFittingNet",
"DipoleFittingNet",
"PolarFittingNet",
]
Loading

0 comments on commit 0e4208d

Please sign in to comment.