support separate r_differentiable and c_differentiable #3240

Merged · 2 commits · Feb 7, 2024
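In short: the single `differentiable` flag on `OutputVariableDef` is split into `r_differentiable` (derivative with respect to atomic coordinates, yielding the force) and `c_differentiable` (derivative with respect to the cell tensor, yielding the virial), with `c_differentiable` requiring `r_differentiable`. A minimal sketch of the new signature as it appears throughout the diffs below (import path taken from the files touched here):

    from deepmd.dpmodel.output_def import (
        FittingOutputDef,
        OutputVariableDef,
    )

    out_def = FittingOutputDef(
        [
            OutputVariableDef(
                "energy",
                [1],
                reduciable=True,
                r_differentiable=True,  # request -dE/dr (force)
                c_differentiable=True,  # request the cell-tensor derivative (virial)
            ),
        ]
    )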
6 changes: 5 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
@@ -199,7 +199,11 @@ def output_def(self):
         return FittingOutputDef(
             [
                 OutputVariableDef(
-                    self.var_name, [self.dim_out], reduciable=True, differentiable=True
+                    self.var_name,
+                    [self.dim_out],
+                    reduciable=True,
+                    r_differentiable=True,
+                    c_differentiable=True,
                 ),
             ]
         )
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/make_base_atomic_model.py
@@ -107,7 +107,10 @@ def do_grad_(
         ) -> bool:
             """Tell if the output variable `var_name` is differentiable."""
             assert var_name is not None
-            return self.fitting_output_def()[var_name].differentiable
+            return (
+                self.fitting_output_def()[var_name].r_differentiable
+                or self.fitting_output_def()[var_name].c_differentiable
+            )

     setattr(BAM, fwd_method_name, BAM.fwd)
     delattr(BAM, "fwd")
6 changes: 5 additions & 1 deletion deepmd/dpmodel/model/pair_tab_model.py
@@ -70,7 +70,11 @@ def fitting_output_def(self) -> FittingOutputDef:
         return FittingOutputDef(
             [
                 OutputVariableDef(
-                    name="energy", shape=[1], reduciable=True, differentiable=True
+                    name="energy",
+                    shape=[1],
+                    reduciable=True,
+                    r_differentiable=True,
+                    c_differentiable=True,
                 )
             ]
         )
10 changes: 8 additions & 2 deletions deepmd/dpmodel/model/transform_output.py
@@ -31,10 +31,13 @@ def fit_output_to_model_output(
         if vdef.reduciable:
             kk_redu = get_reduce_name(kk)
             model_ret[kk_redu] = np.sum(vv, axis=atom_axis)
-        if vdef.differentiable:
+        if vdef.r_differentiable:
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
             # name-holders
             model_ret[kk_derv_r] = None
+        if vdef.c_differentiable:
+            assert vdef.r_differentiable
+            kk_derv_r, kk_derv_c = get_deriv_name(kk)
             model_ret[kk_derv_c] = None
     return model_ret

@@ -57,10 +60,13 @@ def communicate_extended_output(
         if vdef.reduciable:
             kk_redu = get_reduce_name(kk)
             new_ret[kk_redu] = model_ret[kk_redu]
-        if vdef.differentiable:
+        if vdef.r_differentiable:
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
             # name holders
             new_ret[kk_derv_r] = None
+        if vdef.c_differentiable:
+            assert vdef.r_differentiable
+            kk_derv_r, kk_derv_c = get_deriv_name(kk)
             new_ret[kk_derv_c] = None
             new_ret[kk_derv_c + "_redu"] = None
             if not do_atomic_virial:
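A hypothetical illustration (not part of the PR) of what the two flags now do in `fit_output_to_model_output`, assuming `get_deriv_name(kk)` yields `kk + "_derv_r"` and `kk + "_derv_c"`:

    # r-differentiable only: a force placeholder is registered, but no virial key.
    #   OutputVariableDef("e", [1], reduciable=True, r_differentiable=True)
    #   -> model_ret["e_derv_r"] = None
    # r- and c-differentiable: both placeholders are registered.
    #   OutputVariableDef("e", [1], reduciable=True,
    #                     r_differentiable=True, c_differentiable=True)
    #   -> model_ret["e_derv_r"] = None and model_ret["e_derv_c"] = None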
48 changes: 34 additions & 14 deletions deepmd/dpmodel/output_def.py
@@ -68,9 +68,11 @@
             if dd.reduciable:
                 rk = get_reduce_name(kk)
                 check_var(ret[rk], self.md[rk])
-            if dd.differentiable:
+            if dd.r_differentiable:
                 dnr, dnc = get_deriv_name(kk)
                 check_var(ret[dnr], self.md[dnr])
+            if dd.c_differentiable:
+                assert dd.r_differentiable
                 check_var(ret[dnc], self.md[dnc])
         return ret

@@ -160,10 +162,16 @@
         dipole should be [3], polarizability should be [3,3].
     reduciable
         If the variable is reduced.
-    differentiable
+    r_differentiable
         If the variable is differentiated with respect to coordinates
-        of atoms and cell tensor (pbc case). Only reduciable variable
+        of atoms. Only reduciable variables are differentiable.
+        The negative derivative w.r.t. coordinates will be calculated (e.g. force).
+    c_differentiable
+        If the variable is differentiated with respect to the
+        cell tensor (pbc case). Only reduciable variables
         are differentiable.
+        The virial, i.e. the transposed negative gradient with respect to
+        the cell tensor times the cell tensor, will be calculated;
+        see eq. 40 of JCP 159, 054801 (2023).
     category : int
         The category of the output variable.
     """
@@ -173,19 +181,25 @@
         name: str,
         shape: List[int],
         reduciable: bool = False,
-        differentiable: bool = False,
+        r_differentiable: bool = False,
+        c_differentiable: bool = False,
         atomic: bool = True,
         category: int = OutputVariableCategory.OUT.value,
     ):
         self.name = name
         self.shape = list(shape)
         self.atomic = atomic
         self.reduciable = reduciable
-        self.differentiable = differentiable
-        if not self.reduciable and self.differentiable:
-            raise ValueError("only reduciable variable are differentiable")
+        self.r_differentiable = r_differentiable
+        self.c_differentiable = c_differentiable
+        if self.c_differentiable and not self.r_differentiable:
+            raise ValueError("c differentiable requires r_differentiable")
+        if not self.reduciable and self.r_differentiable:
+            raise ValueError("only reduciable variable are r differentiable")
+        if not self.reduciable and self.c_differentiable:
+            raise ValueError("only reduciable variable are c differentiable")

         if self.reduciable and not self.atomic:
-            raise ValueError("only reduciable variable should be atomic")
+            raise ValueError("a reduciable variable should be atomic")
         self.category = category
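A quick sketch of the combinations the new constructor accepts and rejects, following the checks above (each rejected call raises on its own):

    # valid: reduciable energy, differentiable w.r.t. both coordinates and cell
    OutputVariableDef("energy", [1], reduciable=True,
                      r_differentiable=True, c_differentiable=True)

    # rejected: c_differentiable without r_differentiable
    OutputVariableDef("energy", [1], reduciable=True,
                      r_differentiable=False, c_differentiable=True)
    # -> ValueError("c differentiable requires r_differentiable")

    # rejected: differentiable but not reduciable
    OutputVariableDef("logits", [-1], reduciable=False, r_differentiable=True)
    # -> ValueError("only reduciable variable are r differentiable")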


@@ -358,7 +372,8 @@
             rk,
             vv.shape,
             reduciable=False,
-            differentiable=False,
+            r_differentiable=False,
+            c_differentiable=False,
             atomic=False,
             category=apply_operation(vv, OutputVariableOperation.REDU),
         )
@@ -371,21 +386,26 @@
     def_derv_r: Dict[str, OutputVariableDef] = {}
     def_derv_c: Dict[str, OutputVariableDef] = {}
     for kk, vv in def_outp_data.items():
-        if vv.differentiable:
-            rkr, rkc = get_deriv_name(kk)
+        rkr, rkc = get_deriv_name(kk)
+        if vv.r_differentiable:
             def_derv_r[rkr] = OutputVariableDef(
                 rkr,
                 vv.shape + [3],  # noqa: RUF005
                 reduciable=False,
-                differentiable=False,
+                r_differentiable=False,
+                c_differentiable=False,
                 atomic=True,
                 category=apply_operation(vv, OutputVariableOperation.DERV_R),
             )
+        if vv.c_differentiable:
+            assert vv.r_differentiable
+            rkr, rkc = get_deriv_name(kk)
             def_derv_c[rkc] = OutputVariableDef(
                 rkc,
-                vv.shape + [3, 3],  # noqa: RUF005
+                vv.shape + [9],  # noqa: RUF005
                 reduciable=True,
-                differentiable=False,
+                r_differentiable=False,
+                c_differentiable=False,
                 atomic=True,
                 category=apply_operation(vv, OutputVariableOperation.DERV_C),
             )
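Net effect of this hunk, sketched for an energy-like variable of shape [1] (derived names assumed to carry `_derv_r`/`_derv_c` suffixes): the force-like definition stays per-atom with shape [1, 3] and is not reduced, while the virial-like definition is now the flattened [1, 9] per atom (previously [1, 3, 3]) and is reduciable so a global virial can be summed from the atomic one.

    # Hypothetical shapes produced for vv = OutputVariableDef("energy", [1], ...):
    #   def_derv_r["energy_derv_r"].shape == [1, 3]  # per-atom force block, not reduced
    #   def_derv_c["energy_derv_c"].shape == [1, 9]  # per-atom virial, flattened 3x3 -> 9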
6 changes: 5 additions & 1 deletion deepmd/pt/model/model/pair_tab_model.py
@@ -82,7 +82,11 @@ def fitting_output_def(self) -> FittingOutputDef:
         return FittingOutputDef(
             [
                 OutputVariableDef(
-                    name="energy", shape=[1], reduciable=True, differentiable=True
+                    name="energy",
+                    shape=[1],
+                    reduciable=True,
+                    r_differentiable=True,
+                    c_differentiable=True,
                 )
             ]
         )
64 changes: 43 additions & 21 deletions deepmd/pt/model/model/transform_output.py
@@ -56,6 +56,7 @@
     atom_energy: torch.Tensor,
     energy: torch.Tensor,
     extended_coord: torch.Tensor,
+    do_virial: bool = True,
     do_atomic_virial: bool = False,
 ):
     faked_grad = torch.ones_like(energy)
@@ -65,13 +66,16 @@
     )[0]
     assert extended_force is not None
     extended_force = -extended_force
-    extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
-    # the correction sums to zero, which does not contribute to global virial
-    if do_atomic_virial:
-        extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
-        extended_virial = extended_virial + extended_virial_corr
-    # to [...,3,3] -> [...,9]
-    extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9])  # noqa:RUF005
+    if do_virial:
+        extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
+        # the correction sums to zero, which does not contribute to global virial
+        if do_atomic_virial:
+            extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
+            extended_virial = extended_virial + extended_virial_corr
+        # to [...,3,3] -> [...,9]
+        extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9])  # noqa:RUF005
+    else:
+        extended_virial = None

     return extended_force, extended_virial
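A standalone sketch (not the PR's code) of the autograd pattern that `do_virial` now gates: the force is the negative gradient of the energy with respect to extended coordinates, and the per-atom virial is assembled from force and coordinates, then flattened to nine components. The toy energy below is an assumption for the example.

    import torch

    coord = torch.rand([1, 4, 3], requires_grad=True)   # nf x nall x 3
    energy = (coord**2).sum(dim=[1, 2])                 # toy energy, nf

    faked_grad = torch.ones_like(energy)
    (force,) = torch.autograd.grad([energy], [coord], grad_outputs=[faked_grad])
    force = -force                                       # nf x nall x 3
    virial = force.unsqueeze(-1) @ coord.unsqueeze(-2)   # nf x nall x 3 x 3
    virial = virial.view(list(virial.shape[:-2]) + [9])  # nf x nall x 9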


@@ -97,6 +101,7 @@
     svv: torch.Tensor,
     vdef: OutputVariableDef,
     coord_ext: torch.Tensor,
+    do_virial: bool = False,
     do_atomic_virial: bool = False,
 ):
     size = 1
@@ -110,16 +115,25 @@
     for vvi, svvi in zip(split_vv1, split_svv1):
         # nf x nloc x 3, nf x nloc x 9
         ffi, aviri = task_deriv_one(
-            vvi, svvi, coord_ext, do_atomic_virial=do_atomic_virial
+            vvi,
+            svvi,
+            coord_ext,
+            do_virial=do_virial,
+            do_atomic_virial=do_atomic_virial,
         )
         # nf x nloc x 1 x 3, nf x nloc x 1 x 9
         ffi = ffi.unsqueeze(-2)
-        aviri = aviri.unsqueeze(-2)
         split_ff.append(ffi)
-        split_avir.append(aviri)
+        if do_virial:
+            assert aviri is not None
+            aviri = aviri.unsqueeze(-2)
+            split_avir.append(aviri)
     # nf x nloc x v_dim x 3, nf x nloc x v_dim x 9
     ff = torch.concat(split_ff, dim=-2)
-    avir = torch.concat(split_avir, dim=-2)
+    if do_virial:
+        avir = torch.concat(split_avir, dim=-2)
+    else:
+        avir = None

     return ff, avir


@@ -141,18 +155,23 @@
         if vdef.reduciable:
             kk_redu = get_reduce_name(kk)
             model_ret[kk_redu] = torch.sum(vv, dim=atom_axis)
-        if vdef.differentiable:
+        if vdef.r_differentiable:
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
             dr, dc = take_deriv(
                 vv,
                 model_ret[kk_redu],
                 vdef,
                 coord_ext,
+                do_virial=vdef.c_differentiable,
                 do_atomic_virial=do_atomic_virial,
             )
             model_ret[kk_derv_r] = dr
-            model_ret[kk_derv_c] = dc
-            model_ret[kk_derv_c + "_redu"] = torch.sum(model_ret[kk_derv_c], dim=1)
+            if vdef.c_differentiable:
+                assert dc is not None
+                model_ret[kk_derv_c] = dc
+                model_ret[kk_derv_c + "_redu"] = torch.sum(
+                    model_ret[kk_derv_c], dim=1
+                )
     return model_ret


@@ -174,12 +193,12 @@
         if vdef.reduciable:
             kk_redu = get_reduce_name(kk)
             new_ret[kk_redu] = model_ret[kk_redu]
-        if vdef.differentiable:
-            # nf x nloc
-            vldims = get_leading_dims(vv, vdef)
-            # nf x nall
-            mldims = list(mapping.shape)
-            kk_derv_r, kk_derv_c = get_deriv_name(kk)
+        # nf x nloc
+        vldims = get_leading_dims(vv, vdef)
+        # nf x nall
+        mldims = list(mapping.shape)
+        kk_derv_r, kk_derv_c = get_deriv_name(kk)
+        if vdef.r_differentiable:
             # vdim x 3
             derv_r_ext_dims = list(vdef.shape) + [3]  # noqa:RUF005
             mapping = mapping.view(mldims + [1] * len(derv_r_ext_dims)).expand(
@@ -196,10 +215,13 @@
                 src=model_ret[kk_derv_r],
                 reduce="sum",
             )
+        if vdef.c_differentiable:
+            assert vdef.r_differentiable
             derv_c_ext_dims = list(vdef.shape) + [9]  # noqa:RUF005
             # nf x nloc x nvar x 3 -> nf x nloc x nvar x 9
             mapping = torch.tile(
-                mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
+                mapping,
+                [1] * (len(mldims) + len(vdef.shape)) + [3],
             )
             virial = torch.zeros(
                 vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
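A reduced sketch (hypothetical shapes, not the PR's code) of the scatter pattern used above to accumulate per-extended-atom derivatives onto local atoms via `mapping`:

    import torch

    nf, nloc, nall = 1, 3, 5
    mapping = torch.tensor([[0, 1, 2, 0, 1]])   # nf x nall: extended -> local index
    derv_r_ext = torch.rand([nf, nall, 3])      # per-extended-atom forces
    index = mapping.unsqueeze(-1).expand([nf, nall, 3])
    force = torch.zeros([nf, nloc, 3]).scatter_reduce(
        1, index, src=derv_r_ext, reduce="sum"
    )  # nf x nloc x 3: ghost-atom contributions summed onto their local images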
12 changes: 10 additions & 2 deletions deepmd/pt/model/task/denoise.py
@@ -75,10 +75,18 @@ def output_def(self):
         return FittingOutputDef(
             [
                 OutputVariableDef(
-                    "updated_coord", [3], reduciable=False, differentiable=False
+                    "updated_coord",
+                    [3],
+                    reduciable=False,
+                    r_differentiable=False,
+                    c_differentiable=False,
                 ),
                 OutputVariableDef(
-                    "logits", [-1], reduciable=False, differentiable=False
+                    "logits",
+                    [-1],
+                    reduciable=False,
+                    r_differentiable=False,
+                    c_differentiable=False,
                 ),
             ]
         )
20 changes: 17 additions & 3 deletions deepmd/pt/model/task/ener.py
@@ -162,7 +162,11 @@ def output_def(self) -> FittingOutputDef:
         return FittingOutputDef(
             [
                 OutputVariableDef(
-                    self.var_name, [self.dim_out], reduciable=True, differentiable=True
+                    self.var_name,
+                    [self.dim_out],
+                    reduciable=True,
+                    r_differentiable=True,
+                    c_differentiable=True,
                 ),
             ]
         )
@@ -459,9 +463,19 @@ def __init__(
     def output_def(self):
         return FittingOutputDef(
             [
-                OutputVariableDef("energy", [1], reduciable=True, differentiable=False),
                 OutputVariableDef(
-                    "dforce", [3], reduciable=False, differentiable=False
+                    "energy",
+                    [1],
+                    reduciable=True,
+                    r_differentiable=False,
+                    c_differentiable=False,
+                ),
+                OutputVariableDef(
+                    "dforce",
+                    [3],
+                    reduciable=False,
+                    r_differentiable=False,
+                    c_differentiable=False,
                 ),
             ]
         )