Commit

resolve comments
ChiahsinChu committed Oct 24, 2024
1 parent 044023e commit 5cdef87
Showing 3 changed files with 23 additions and 18 deletions.
12 changes: 7 additions & 5 deletions deepmd/dpmodel/fitting/general_fitting.py
@@ -155,15 +155,17 @@ def __init__(
             self.fparam_inv_std = np.ones(self.numb_fparam)  # pylint: disable=no-explicit-dtype
         else:
             self.fparam_avg, self.fparam_inv_std = None, None
-        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
+        if self.numb_aparam > 0:
             self.aparam_avg = np.zeros(self.numb_aparam)  # pylint: disable=no-explicit-dtype
             self.aparam_inv_std = np.ones(self.numb_aparam)  # pylint: disable=no-explicit-dtype
         else:
             self.aparam_avg, self.aparam_inv_std = None, None
         # init networks
-        in_dim = self.dim_descrpt + self.numb_fparam
-        if not self.use_aparam_as_mask:
-            in_dim += self.numb_aparam
+        in_dim = (
+            self.dim_descrpt
+            + self.numb_fparam
+            + (0 if self.use_aparam_as_mask else self.numb_aparam)
+        )
         self.nets = NetworkCollection(
             1 if not self.mixed_types else 0,
             self.ntypes,
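
As a sanity check, the inlined expression matches the old two-step `in_dim` computation for both values of `use_aparam_as_mask`; a standalone sketch with illustrative numbers (128/2/3 are not taken from the model):

```python
# Standalone sketch: the old and new in_dim computations agree.
# The numeric arguments below are illustrative, not model values.
def in_dim_old(dim_descrpt, numb_fparam, numb_aparam, use_aparam_as_mask):
    in_dim = dim_descrpt + numb_fparam
    if not use_aparam_as_mask:
        in_dim += numb_aparam
    return in_dim


def in_dim_new(dim_descrpt, numb_fparam, numb_aparam, use_aparam_as_mask):
    return (
        dim_descrpt
        + numb_fparam
        + (0 if use_aparam_as_mask else numb_aparam)
    )


for mask in (True, False):
    assert in_dim_old(128, 2, 3, mask) == in_dim_new(128, 2, 3, mask)
```
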
@@ -391,7 +393,7 @@ def _call_common(
                 axis=-1,
             )
         # check aparam dim, concate to input descriptor
-        if not self.use_aparam_as_mask and self.numb_aparam > 0:
+        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
             assert aparam is not None, "aparam should not be None"
             if aparam.shape[-1] != self.numb_aparam:
                 raise ValueError(
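
For reference, a rough standalone sketch of what the reordered guard protects: validate the trailing `aparam` dimension, standardize it with the stored statistics, and concatenate it to the per-atom input. The helper name, shapes, and error message below are illustrative assumptions, not the exact implementation:

```python
import numpy as np


# Illustrative sketch only; append_aparam and its error message are hypothetical.
def append_aparam(descriptor, aparam, numb_aparam, aparam_avg, aparam_inv_std):
    assert aparam is not None, "aparam should not be None"
    if aparam.shape[-1] != numb_aparam:
        raise ValueError(
            f"got aparam of dim {aparam.shape[-1]}, expected {numb_aparam}"
        )
    # standardize, then append along the feature axis
    aparam = (aparam - aparam_avg) * aparam_inv_std
    return np.concatenate([descriptor, aparam], axis=-1)


xx = np.random.rand(4, 6, 128)  # (nframes, nloc, dim_descrpt)
ap = np.random.rand(4, 6, 3)    # (nframes, nloc, numb_aparam)
out = append_aparam(xx, ap, 3, np.zeros(3), np.ones(3))
assert out.shape == (4, 6, 131)
```
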
12 changes: 7 additions & 5 deletions deepmd/pt/model/task/fitting.py
@@ -198,7 +198,7 @@ def __init__(
             )
         else:
             self.fparam_avg, self.fparam_inv_std = None, None
-        if not self.use_aparam_as_mask and self.numb_aparam > 0:
+        if self.numb_aparam > 0:
             self.register_buffer(
                 "aparam_avg",
                 torch.zeros(self.numb_aparam, dtype=self.prec, device=device),
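
In the PyTorch backend the statistics are `register_buffer` entries rather than plain attributes, so they end up in the state dict and move with `.to(device)` without being trainable. A minimal standalone sketch with a hypothetical `ToyFitting` module:

```python
import torch


class ToyFitting(torch.nn.Module):
    """Hypothetical module mirroring the buffer registration above."""

    def __init__(self, numb_aparam: int):
        super().__init__()
        if numb_aparam > 0:
            # Buffers are serialized and device-moved, but not trained.
            self.register_buffer("aparam_avg", torch.zeros(numb_aparam))
            self.register_buffer("aparam_inv_std", torch.ones(numb_aparam))
        else:
            self.aparam_avg, self.aparam_inv_std = None, None


m = ToyFitting(numb_aparam=3)
assert "aparam_avg" in m.state_dict()
assert len(list(m.parameters())) == 0  # buffers are not parameters
```
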
@@ -210,9 +210,11 @@
         else:
             self.aparam_avg, self.aparam_inv_std = None, None
 
-        in_dim = self.dim_descrpt + self.numb_fparam
-        if not self.use_aparam_as_mask:
-            in_dim += self.numb_aparam
+        in_dim = (
+            self.dim_descrpt
+            + self.numb_fparam
+            + (0 if self.use_aparam_as_mask else self.numb_aparam)
+        )
 
         self.filter_layers = NetworkCollection(
             1 if not self.mixed_types else 0,
@@ -444,7 +446,7 @@ def _forward_common(
                 dim=-1,
             )
         # check aparam dim, concate to input descriptor
-        if not self.use_aparam_as_mask and self.numb_aparam > 0:
+        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
             assert aparam is not None, "aparam should not be None"
             assert self.aparam_avg is not None
             assert self.aparam_inv_std is not None
17 changes: 9 additions & 8 deletions deepmd/tf/fit/ener.py
@@ -340,7 +340,7 @@ def compute_input_stats(self, all_stat: dict, protection: float = 1e-2) -> None:
                     self.fparam_std[ii] = protection
             self.fparam_inv_std = 1.0 / self.fparam_std
         # stat aparam
-        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
+        if self.numb_aparam > 0:
             sys_sumv = []
             sys_sumv2 = []
             sys_sumn = []
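
The per-system sums accumulated here feed `aparam_avg` and `aparam_inv_std`; a simplified numpy sketch of a pooled mean and protection-clipped inverse standard deviation (the per-system bookkeeping is flattened away, so this is an illustration rather than the exact implementation):

```python
import numpy as np


# Simplified sketch: pool aparam over all frames/atoms, then compute the mean
# and an inverse std clipped from below by `protection` (cf. the 1e-2 default).
def aparam_stats(aparam_list, protection=1e-2):
    pooled = np.concatenate(
        [a.reshape(-1, a.shape[-1]) for a in aparam_list], axis=0
    )
    avg = pooled.mean(axis=0)
    std = pooled.std(axis=0)
    std = np.where(std < protection, protection, std)
    return avg, 1.0 / std


avg, inv_std = aparam_stats([np.random.rand(10, 5, 3), np.random.rand(8, 5, 3)])
assert avg.shape == inv_std.shape == (3,)
```
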
@@ -505,7 +505,7 @@ def build(
                 self.fparam_avg = 0.0
             if self.fparam_inv_std is None:
                 self.fparam_inv_std = 1.0
-        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
+        if self.numb_aparam > 0:
             if self.aparam_avg is None:
                 self.aparam_avg = 0.0
             if self.aparam_inv_std is None:
@@ -561,7 +561,7 @@ def build(
                 trainable=False,
                 initializer=tf.constant_initializer(self.fparam_inv_std),
             )
-        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
+        if self.numb_aparam > 0:
             t_aparam_avg = tf.get_variable(
                 "t_aparam_avg",
                 self.numb_aparam,
@@ -602,7 +602,7 @@ def build(
             fparam = (fparam - t_fparam_avg) * t_fparam_istd
 
         aparam = None
-        if not self.use_aparam_as_mask and self.numb_aparam > 0:
+        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
             aparam = input_dict["aparam"]
             aparam = tf.reshape(aparam, [-1, self.numb_aparam])
             aparam = (aparam - t_aparam_avg) * t_aparam_istd
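
Taken together, the `build()` hunks decouple two conditions: the statistics variables (`t_aparam_avg`, `t_aparam_istd`) are now created whenever `numb_aparam > 0`, while normalizing `aparam` and feeding it to the network stays gated on `not use_aparam_as_mask`. A plain-Python sketch of that control flow (the helper and its return values are illustrative):

```python
# Illustrative control flow only; the helper is not part of the codebase.
def aparam_handling(numb_aparam: int, use_aparam_as_mask: bool):
    create_stats_vars = numb_aparam > 0
    feed_to_network = numb_aparam > 0 and not use_aparam_as_mask
    return create_stats_vars, feed_to_network


# With use_aparam_as_mask=True the statistics variables still exist,
# but aparam is no longer normalized and concatenated to the descriptor.
assert aparam_handling(3, True) == (True, False)
assert aparam_handling(3, False) == (True, True)
assert aparam_handling(0, False) == (False, False)
```
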
@@ -895,9 +895,6 @@ def serialize(self, suffix: str = "") -> dict:
         dict
             The serialized data
         """
-        in_dim = self.dim_descrpt + self.numb_fparam
-        if not self.use_aparam_as_mask:
-            in_dim += self.numb_aparam
         data = {
             "@class": "Fitting",
             "type": "ener",
@@ -924,7 +921,11 @@
             "nets": self.serialize_network(
                 ntypes=self.ntypes,
                 ndim=0 if self.mixed_types else 1,
-                in_dim=in_dim,
+                in_dim=(
+                    self.dim_descrpt
+                    + self.numb_fparam
+                    + (0 if self.use_aparam_as_mask else self.numb_aparam)
+                ),
                 neuron=self.n_neuron,
                 activation_function=self.activation_function_name,
                 resnet_dt=self.resnet_dt,

0 comments on commit 5cdef87
