Skip to content

Commit

Permalink
Merge branch 'devel' into spin_rf
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored Feb 21, 2024
2 parents 5841e59 + d629616 commit 96426c3
Show file tree
Hide file tree
Showing 65 changed files with 1,190 additions and 274 deletions.
1 change: 1 addition & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DescrptSeA,
)
from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
InvarFitting,
)
from deepmd.dpmodel.output_def import (
Expand Down
36 changes: 25 additions & 11 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,33 +118,47 @@ def serialize(self) -> dict:
def deserialize(cls):
pass

def do_grad(
def do_grad_r(
self,
var_name: Optional[str] = None,
) -> bool:
"""Tell if the output variable `var_name` is differentiable.
if var_name is None, returns if any of the variable is differentiable.
"""Tell if the output variable `var_name` is r_differentiable.
if var_name is None, returns if any of the variable is r_differentiable.
"""
odef = self.fitting_output_def()
if var_name is None:
require: List[bool] = []
for vv in odef.keys():
require.append(self.do_grad_(vv))
require.append(self.do_grad_(vv, "r"))
return any(require)
else:
return self.do_grad_(var_name)
return self.do_grad_(var_name, "r")

def do_grad_(
def do_grad_c(
self,
var_name: str,
var_name: Optional[str] = None,
) -> bool:
"""Tell if the output variable `var_name` is c_differentiable.
if var_name is None, returns if any of the variable is c_differentiable.
"""
odef = self.fitting_output_def()
if var_name is None:
require: List[bool] = []
for vv in odef.keys():
require.append(self.do_grad_(vv, "c"))
return any(require)
else:
return self.do_grad_(var_name, "c")

def do_grad_(self, var_name: str, base: str) -> bool:
"""Tell if the output variable `var_name` is differentiable."""
assert var_name is not None
return (
self.fitting_output_def()[var_name].r_differentiable
or self.fitting_output_def()[var_name].c_differentiable
)
assert base in ["c", "r"]
if base == "c":
return self.fitting_output_def()[var_name].c_differentiable
return self.fitting_output_def()[var_name].r_differentiable

setattr(BAM, fwd_method_name, BAM.fwd)
delattr(BAM, "fwd")
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .dipole_fitting import (
DipoleFitting,
)
from .ener_fitting import (
EnergyFittingNet,
)
from .invar_fitting import (
InvarFitting,
)
Expand All @@ -16,5 +19,6 @@
"InvarFitting",
"make_base_fitting",
"DipoleFitting",
"EnergyFittingNet",
"PolarFitting",
]
19 changes: 15 additions & 4 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,14 @@ class DipoleFitting(GeneralFitting):
mixed_types
If true, use a uniform fitting net for all atom types, otherwise use
different fitting nets for different atom types.
exclude_types: List[int]
exclude_types
Atomic contributions of the excluded atom types are set zero.
r_differentiable
If the variable is differentiated with respect to coordinates of atoms.
Only reduciable variable are differentiable.
c_differentiable
If the variable is differentiated with respect to the cell tensor (pbc case).
Only reduciable variable are differentiable.
"""

def __init__(
Expand All @@ -94,6 +99,8 @@ def __init__(
spin: Any = None,
mixed_types: bool = False,
exclude_types: List[int] = [],
r_differentiable: bool = True,
c_differentiable: bool = True,
old_impl=False,
):
# seed, uniform_seed are not included
Expand All @@ -109,6 +116,8 @@ def __init__(
raise NotImplementedError("atom_ener is not implemented")

self.embedding_width = embedding_width
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
super().__init__(
var_name=var_name,
ntypes=ntypes,
Expand Down Expand Up @@ -139,6 +148,8 @@ def serialize(self) -> dict:
data = super().serialize()
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["r_differentiable"] = self.r_differentiable
data["c_differentiable"] = self.c_differentiable
return data

def output_def(self):
Expand All @@ -148,8 +159,8 @@ def output_def(self):
self.var_name,
[3],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
r_differentiable=self.r_differentiable,
c_differentiable=self.c_differentiable,
),
]
)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def output_def(self):
self.var_name,
[3, 3],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
r_differentiable=False,
c_differentiable=False,
),
]
)
Expand Down
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@
make_model,
)

DPModel = make_model(DPAtomicModel)

# use "class" to resolve "Variable not allowed in type expression"
class DPModel(make_model(DPAtomicModel)):
pass
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
)
Expand Down Expand Up @@ -45,7 +48,7 @@ def make_model(T_AtomicModel):
"""

class CM(T_AtomicModel):
class CM(T_AtomicModel, NativeOP):
def __init__(
self,
*args,
Expand Down
41 changes: 41 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
from deepmd.dpmodel.model.dp_model import (
DPModel,
)


def get_model(data: dict) -> DPModel:
"""Get a DPModel from a dictionary.
Parameters
----------
data : dict
The data to construct the model.
"""
descriptor_type = data["descriptor"].pop("type")
fitting_type = data["fitting_net"].pop("type")
if descriptor_type == "se_e2_a":
descriptor = DescrptSeA(
**data["descriptor"],
)
else:
raise ValueError(f"Unknown descriptor type {descriptor_type}")
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
**data["fitting_net"],
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")
return DPModel(
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
)
18 changes: 11 additions & 7 deletions deepmd/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def eval(
**kwargs,
)
sel_natoms = self._get_sel_natoms(atom_types[0])
if sel_natoms == 0:
sel_natoms = atom_types.shape[-1] # set to natoms
if atomic:
return results[self.output_tensor_name].reshape(nframes, sel_natoms, -1)
else:
Expand Down Expand Up @@ -184,22 +186,24 @@ def eval_full(
aparam=aparam,
**kwargs,
)

sel_natoms = self._get_sel_natoms(atom_types[0])
if sel_natoms == 0:
sel_natoms = atom_types.shape[-1] # set to natoms
energy = results[f"{self.output_tensor_name}_redu"].reshape(nframes, -1)
force = results[f"{self.output_tensor_name}_derv_r"].reshape(
nframes, -1, natoms, 3
)
virial = results[f"{self.output_tensor_name}_derv_c_redu"].reshape(
nframes, -1, 9
)
atomic_energy = results[self.output_tensor_name].reshape(
nframes, sel_natoms, -1
)
atomic_virial = results[f"{self.output_tensor_name}_derv_c"].reshape(
nframes, -1, natoms, 9
)

if atomic:
atomic_energy = results[self.output_tensor_name].reshape(
nframes, sel_natoms, -1
)
atomic_virial = results[f"{self.output_tensor_name}_derv_c"].reshape(
nframes, -1, natoms, 9
)
return (
energy,
force,
Expand Down
49 changes: 29 additions & 20 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,17 @@ def _eval_model(
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
coords.reshape([-1, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
if cells is not None:
box_input = torch.tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
cells.reshape([-1, 3, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
else:
box_input = None

Expand All @@ -369,6 +373,9 @@ def _eval_model(
shape = self._get_output_shape(odef, nframes, natoms)
out = batch_output[pt_name].reshape(shape).detach().cpu().numpy()
results.append(out)
else:
shape = self._get_output_shape(odef, nframes, natoms)
results.append(np.full(np.abs(shape), np.nan)) # this is kinda hacky
return tuple(results)

def _get_output_shape(self, odef, nframes, natoms):
Expand Down Expand Up @@ -420,7 +427,7 @@ def eval_model(
if cells is not None:
assert isinstance(cells, torch.Tensor), err_msg
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
atom_types = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
elif isinstance(coords, np.ndarray):
if cells is not None:
assert isinstance(cells, np.ndarray), err_msg
Expand All @@ -441,17 +448,17 @@ def eval_model(
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
box_input = None
if cells is None:
pbc = False
else:
pbc = True
box_input = torch.tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)

for ii in range(num_iter):
Expand Down Expand Up @@ -527,35 +534,37 @@ def eval_model(
energy_out = (
torch.cat(energy_out)
if energy_out
else torch.zeros([nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
else torch.zeros(
[nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
atomic_energy_out = (
torch.cat(atomic_energy_out)
if atomic_energy_out
else torch.zeros([nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
DEVICE
else torch.zeros(
[nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
force_out = (
torch.cat(force_out)
if force_out
else torch.zeros([nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
DEVICE
else torch.zeros(
[nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
virial_out = (
torch.cat(virial_out)
if virial_out
else torch.zeros([nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
DEVICE
else torch.zeros(
[nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
atomic_virial_out = (
torch.cat(atomic_virial_out)
if atomic_virial_out
else torch.zeros(
[nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
[nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None
logits_out = torch.cat(logits_out) if logits_out else None
Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/infer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def __init__(

@staticmethod
def get_data(data):
batch_data = next(iter(data))
with torch.device("cpu"):
batch_data = next(iter(data))
for key in batch_data.keys():
if key == "sid" or key == "fid":
continue
Expand Down Expand Up @@ -235,7 +236,8 @@ def run(self):
), # setting to 0 diverges the behavior of its iterator; should be >=1
drop_last=False,
)
data = iter(dataloader)
with torch.device("cpu"):
data = iter(dataloader)

single_results = {}
sum_natoms = 0
Expand Down
Loading

0 comments on commit 96426c3

Please sign in to comment.