Skip to content

Commit

Permalink
Merge branch 'devel' into data_stat
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored Feb 7, 2024
2 parents c1369c9 + a7153b1 commit 67a22fa
Show file tree
Hide file tree
Showing 12 changed files with 231 additions and 73 deletions.
6 changes: 3 additions & 3 deletions .github/ISSUE_TEMPLATE/bug-report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ body:
validations:
required: true
- type: input
id: tf-version
id: backend-version
attributes:
label: TensorFlow Version
description: "The version will be printed when running DeePMD-kit."
label: Backend and its version
description: "The backend and its version will be printed when running DeePMD-kit, e.g. TensorFlow v2.15.0."
validations:
required: true
- type: dropdown
Expand Down
6 changes: 3 additions & 3 deletions .github/ISSUE_TEMPLATE/generic-issue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ body:
validations:
required: true
- type: input
id: tf-version
id: backend-version
attributes:
label: TensorFlow Version
description: "The version will be printed when running DeePMD-kit."
label: Backend and its version
description: "The backend and its version will be printed when running DeePMD-kit, e.g. TensorFlow v2.15.0."
validations:
required: true
- type: textarea
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,11 @@ def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
self.var_name, [self.dim_out], reduciable=True, differentiable=True
self.var_name,
[self.dim_out],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
),
]
)
Expand Down
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,10 @@ def do_grad_(
) -> bool:
"""Tell if the output variable `var_name` is differentiable."""
assert var_name is not None
return self.fitting_output_def()[var_name].differentiable
return (
self.fitting_output_def()[var_name].r_differentiable
or self.fitting_output_def()[var_name].c_differentiable
)

setattr(BAM, fwd_method_name, BAM.fwd)
delattr(BAM, "fwd")
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/model/pair_tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
OutputVariableDef(
name="energy", shape=[1], reduciable=True, differentiable=True
name="energy",
shape=[1],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
)
]
)
Expand Down
10 changes: 8 additions & 2 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ def fit_output_to_model_output(
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
model_ret[kk_redu] = np.sum(vv, axis=atom_axis)
if vdef.differentiable:
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name-holders
model_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
model_ret[kk_derv_c] = None
return model_ret

Expand All @@ -57,10 +60,13 @@ def communicate_extended_output(
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
if vdef.differentiable:
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name holders
new_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if not do_atomic_virial:
Expand Down
48 changes: 34 additions & 14 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ def __call__(
if dd.reduciable:
rk = get_reduce_name(kk)
check_var(ret[rk], self.md[rk])
if dd.differentiable:
if dd.r_differentiable:
dnr, dnc = get_deriv_name(kk)
check_var(ret[dnr], self.md[dnr])
if dd.c_differentiable:
assert dd.r_differentiable
check_var(ret[dnc], self.md[dnc])
return ret

Expand Down Expand Up @@ -160,10 +162,16 @@ class OutputVariableDef:
dipole should be [3], polarizabilty should be [3,3].
reduciable
If the variable is reduced.
differentiable
r_differentiable
If the variable is differentiated with respect to coordinates
of atoms and cell tensor (pbc case). Only reduciable variable
of atoms. Only reduciable variable are differentiable.
Negative derivative w.r.t. coordinates will be calcualted. (e.g. force)
c_differentiable
If the variable is differentiated with respect to the
cell tensor (pbc case). Only reduciable variable
are differentiable.
Virial, the transposed negative gradient with cell tensor times
cell tensor, will be calculated, see eq 40 JCP 159, 054801 (2023).
category : int
The category of the output variable.
"""
Expand All @@ -173,19 +181,25 @@ def __init__(
name: str,
shape: List[int],
reduciable: bool = False,
differentiable: bool = False,
r_differentiable: bool = False,
c_differentiable: bool = False,
atomic: bool = True,
category: int = OutputVariableCategory.OUT.value,
):
self.name = name
self.shape = list(shape)
self.atomic = atomic
self.reduciable = reduciable
self.differentiable = differentiable
if not self.reduciable and self.differentiable:
raise ValueError("only reduciable variable are differentiable")
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
if self.c_differentiable and not self.r_differentiable:
raise ValueError("c differentiable requires r_differentiable")
if not self.reduciable and self.r_differentiable:
raise ValueError("only reduciable variable are r differentiable")
if not self.reduciable and self.c_differentiable:
raise ValueError("only reduciable variable are c differentiable")
if self.reduciable and not self.atomic:
raise ValueError("only reduciable variable should be atomic")
raise ValueError("a reduciable variable should be atomic")
self.category = category


Expand Down Expand Up @@ -358,7 +372,8 @@ def do_reduce(
rk,
vv.shape,
reduciable=False,
differentiable=False,
r_differentiable=False,
c_differentiable=False,
atomic=False,
category=apply_operation(vv, OutputVariableOperation.REDU),
)
Expand All @@ -371,21 +386,26 @@ def do_derivative(
def_derv_r: Dict[str, OutputVariableDef] = {}
def_derv_c: Dict[str, OutputVariableDef] = {}
for kk, vv in def_outp_data.items():
if vv.differentiable:
rkr, rkc = get_deriv_name(kk)
rkr, rkc = get_deriv_name(kk)
if vv.r_differentiable:
def_derv_r[rkr] = OutputVariableDef(
rkr,
vv.shape + [3], # noqa: RUF005
reduciable=False,
differentiable=False,
r_differentiable=False,
c_differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_R),
)
if vv.c_differentiable:
assert vv.r_differentiable
rkr, rkc = get_deriv_name(kk)
def_derv_c[rkc] = OutputVariableDef(
rkc,
vv.shape + [3, 3], # noqa: RUF005
vv.shape + [9], # noqa: RUF005
reduciable=True,
differentiable=False,
r_differentiable=False,
c_differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_C),
)
Expand Down
6 changes: 5 additions & 1 deletion deepmd/pt/model/model/pair_tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
OutputVariableDef(
name="energy", shape=[1], reduciable=True, differentiable=True
name="energy",
shape=[1],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
)
]
)
Expand Down
64 changes: 43 additions & 21 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def task_deriv_one(
atom_energy: torch.Tensor,
energy: torch.Tensor,
extended_coord: torch.Tensor,
do_virial: bool = True,
do_atomic_virial: bool = False,
):
faked_grad = torch.ones_like(energy)
Expand All @@ -65,13 +66,16 @@ def task_deriv_one(
)[0]
assert extended_force is not None
extended_force = -extended_force
extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
extended_virial = extended_virial + extended_virial_corr
# to [...,3,3] -> [...,9]
extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005
if do_virial:
extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
extended_virial = extended_virial + extended_virial_corr
# to [...,3,3] -> [...,9]
extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005
else:
extended_virial = None
return extended_force, extended_virial


Expand All @@ -97,6 +101,7 @@ def take_deriv(
svv: torch.Tensor,
vdef: OutputVariableDef,
coord_ext: torch.Tensor,
do_virial: bool = False,
do_atomic_virial: bool = False,
):
size = 1
Expand All @@ -110,16 +115,25 @@ def take_deriv(
for vvi, svvi in zip(split_vv1, split_svv1):
# nf x nloc x 3, nf x nloc x 9
ffi, aviri = task_deriv_one(
vvi, svvi, coord_ext, do_atomic_virial=do_atomic_virial
vvi,
svvi,
coord_ext,
do_virial=do_virial,
do_atomic_virial=do_atomic_virial,
)
# nf x nloc x 1 x 3, nf x nloc x 1 x 9
ffi = ffi.unsqueeze(-2)
aviri = aviri.unsqueeze(-2)
split_ff.append(ffi)
split_avir.append(aviri)
if do_virial:
assert aviri is not None
aviri = aviri.unsqueeze(-2)
split_avir.append(aviri)
# nf x nloc x v_dim x 3, nf x nloc x v_dim x 9
ff = torch.concat(split_ff, dim=-2)
avir = torch.concat(split_avir, dim=-2)
if do_virial:
avir = torch.concat(split_avir, dim=-2)
else:
avir = None
return ff, avir


Expand All @@ -141,18 +155,23 @@ def fit_output_to_model_output(
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
model_ret[kk_redu] = torch.sum(vv, dim=atom_axis)
if vdef.differentiable:
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
dr, dc = take_deriv(
vv,
model_ret[kk_redu],
vdef,
coord_ext,
do_virial=vdef.c_differentiable,
do_atomic_virial=do_atomic_virial,
)
model_ret[kk_derv_r] = dr
model_ret[kk_derv_c] = dc
model_ret[kk_derv_c + "_redu"] = torch.sum(model_ret[kk_derv_c], dim=1)
if vdef.c_differentiable:
assert dc is not None
model_ret[kk_derv_c] = dc
model_ret[kk_derv_c + "_redu"] = torch.sum(
model_ret[kk_derv_c], dim=1
)
return model_ret


Expand All @@ -174,12 +193,12 @@ def communicate_extended_output(
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
if vdef.differentiable:
# nf x nloc
vldims = get_leading_dims(vv, vdef)
# nf x nall
mldims = list(mapping.shape)
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# nf x nloc
vldims = get_leading_dims(vv, vdef)
# nf x nall
mldims = list(mapping.shape)
kk_derv_r, kk_derv_c = get_deriv_name(kk)
if vdef.r_differentiable:
# vdim x 3
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = mapping.view(mldims + [1] * len(derv_r_ext_dims)).expand(
Expand All @@ -196,10 +215,13 @@ def communicate_extended_output(
src=model_ret[kk_derv_r],
reduce="sum",
)
if vdef.c_differentiable:
assert vdef.r_differentiable
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005
# nf x nloc x nvar x 3 -> nf x nloc x nvar x 9
mapping = torch.tile(
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
mapping,
[1] * (len(mldims) + len(vdef.shape)) + [3],
)
virial = torch.zeros(
vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
Expand Down
12 changes: 10 additions & 2 deletions deepmd/pt/model/task/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,18 @@ def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
"updated_coord", [3], reduciable=False, differentiable=False
"updated_coord",
[3],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
),
OutputVariableDef(
"logits", [-1], reduciable=False, differentiable=False
"logits",
[-1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
),
]
)
Expand Down
Loading

0 comments on commit 67a22fa

Please sign in to comment.