diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py
index 62aafc6207..25d15b2e75 100644
--- a/deepmd/dpmodel/fitting/general_fitting.py
+++ b/deepmd/dpmodel/fitting/general_fitting.py
@@ -173,7 +173,11 @@ def __init__(
         else:
             self.aparam_avg, self.aparam_inv_std = None, None
         # init networks
-        in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
+        in_dim = (
+            self.dim_descrpt
+            + self.numb_fparam
+            + (0 if self.use_aparam_as_mask else self.numb_aparam)
+        )
         self.nets = NetworkCollection(
             1 if not self.mixed_types else 0,
             self.ntypes,
@@ -401,7 +405,7 @@ def _call_common(
                 axis=-1,
             )
         # check aparam dim, concate to input descriptor
-        if self.numb_aparam > 0:
+        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
             assert aparam is not None, "aparam should not be None"
             if aparam.shape[-1] != self.numb_aparam:
                 raise ValueError(
diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py
index 893853bb38..2a251834fe 100644
--- a/deepmd/dpmodel/fitting/invar_fitting.py
+++ b/deepmd/dpmodel/fitting/invar_fitting.py
@@ -139,10 +139,6 @@ def __init__(
             raise NotImplementedError("tot_ener_zero is not implemented")
         if spin is not None:
             raise NotImplementedError("spin is not implemented")
-        if use_aparam_as_mask:
-            raise NotImplementedError("use_aparam_as_mask is not implemented")
-        if use_aparam_as_mask:
-            raise NotImplementedError("use_aparam_as_mask is not implemented")
         if layer_name is not None:
             raise NotImplementedError("layer_name is not implemented")
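Editor's note: the input-dimension rule changed by the hunk above can be summarized as a pure function. A minimal sketch (the name `fitting_in_dim` is hypothetical, not part of the deepmd API):

```python
def fitting_in_dim(
    dim_descrpt: int,
    numb_fparam: int,
    numb_aparam: int,
    use_aparam_as_mask: bool,
) -> int:
    """Input width of the fitting net: aparam is concatenated to the
    descriptor only when it is not being used as a mask."""
    return dim_descrpt + numb_fparam + (0 if use_aparam_as_mask else numb_aparam)


# With masking enabled, numb_aparam no longer contributes to the input width.
assert fitting_in_dim(100, 2, 4, use_aparam_as_mask=False) == 106
assert fitting_in_dim(100, 2, 4, use_aparam_as_mask=True) == 102
```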
""" def __init__( @@ -147,6 +149,7 @@ def __init__( trainable: Union[bool, list[bool]] = True, remove_vaccum_contribution: Optional[list[bool]] = None, type_map: Optional[list[str]] = None, + use_aparam_as_mask: bool = False, **kwargs, ): super().__init__() @@ -164,6 +167,7 @@ def __init__( self.rcond = rcond self.seed = seed self.type_map = type_map + self.use_aparam_as_mask = use_aparam_as_mask # order matters, should be place after the assignment of ntypes self.reinit_exclude(exclude_types) self.trainable = trainable @@ -208,7 +212,11 @@ def __init__( else: self.aparam_avg, self.aparam_inv_std = None, None - in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + ) self.filter_layers = NetworkCollection( 1 if not self.mixed_types else 0, @@ -293,13 +301,12 @@ def serialize(self) -> dict: # "trainable": self.trainable , # "atom_ener": self.atom_ener , # "layer_name": self.layer_name , - # "use_aparam_as_mask": self.use_aparam_as_mask , # "spin": self.spin , ## NOTICE: not supported by far "tot_ener_zero": False, "trainable": [self.trainable] * (len(self.neuron) + 1), "layer_name": None, - "use_aparam_as_mask": False, + "use_aparam_as_mask": self.use_aparam_as_mask, "spin": None, } @@ -441,7 +448,7 @@ def _forward_common( dim=-1, ) # check aparam dim, concate to input descriptor - if self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" assert self.aparam_avg is not None assert self.aparam_inv_std is not None diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index 230046b74b..e76e1d2063 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -77,7 +77,8 @@ class InvarFitting(GeneralFitting): The `set_davg_zero` key in the descrptor should be set. type_map: list[str], Optional A list of strings. Give the name to each type of atoms. - + use_aparam_as_mask: bool + If True, the aparam will not be used in fitting net for embedding. 
""" def __init__( @@ -99,6 +100,7 @@ def __init__( exclude_types: list[int] = [], atom_ener: Optional[list[Optional[torch.Tensor]]] = None, type_map: Optional[list[str]] = None, + use_aparam_as_mask: bool = False, **kwargs, ): self.dim_out = dim_out @@ -122,6 +124,7 @@ def __init__( if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0 else [x is not None for x in atom_ener], type_map=type_map, + use_aparam_as_mask=use_aparam_as_mask, **kwargs, ) diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py index b01574cf87..330ea57179 100644 --- a/deepmd/tf/fit/ener.py +++ b/deepmd/tf/fit/ener.py @@ -384,7 +384,7 @@ def _build_lower( ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam]) ext_fparam = tf.cast(ext_fparam, self.fitting_precision) layer = tf.concat([layer, ext_fparam], axis=1) - if aparam is not None: + if aparam is not None and not self.use_aparam_as_mask: ext_aparam = tf.slice( aparam, [0, start_index * self.numb_aparam], @@ -561,7 +561,7 @@ def build( trainable=False, initializer=tf.constant_initializer(self.fparam_inv_std), ) - if self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: t_aparam_avg = tf.get_variable( "t_aparam_avg", self.numb_aparam, @@ -576,6 +576,13 @@ def build( trainable=False, initializer=tf.constant_initializer(self.aparam_inv_std), ) + else: + t_aparam_avg = tf.zeros( + self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION + ) + t_aparam_istd = tf.ones( + self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION + ) inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt]) if len(self.atom_ener): @@ -602,12 +609,11 @@ def build( fparam = (fparam - t_fparam_avg) * t_fparam_istd aparam = None - if not self.use_aparam_as_mask: - if self.numb_aparam > 0: - aparam = input_dict["aparam"] - aparam = tf.reshape(aparam, [-1, self.numb_aparam]) - aparam = (aparam - t_aparam_avg) * t_aparam_istd - aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) + if self.numb_aparam > 0 and not self.use_aparam_as_mask: + aparam = input_dict["aparam"] + aparam = tf.reshape(aparam, [-1, self.numb_aparam]) + aparam = (aparam - t_aparam_avg) * t_aparam_istd + aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) atype_nall = tf.reshape(atype, [-1, natoms[1]]) self.atype_nloc = tf.slice( @@ -783,7 +789,7 @@ def init_variables( self.fparam_inv_std = get_tensor_by_name_from_graph( graph, f"fitting_attr{suffix}/t_fparam_istd" ) - if self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: self.aparam_avg = get_tensor_by_name_from_graph( graph, f"fitting_attr{suffix}/t_aparam_avg" ) @@ -883,7 +889,7 @@ def deserialize(cls, data: dict, suffix: str = ""): if fitting.numb_fparam > 0: fitting.fparam_avg = data["@variables"]["fparam_avg"] fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"] - if fitting.numb_aparam > 0: + if fitting.numb_aparam > 0 and not fitting.use_aparam_as_mask: fitting.aparam_avg = data["@variables"]["aparam_avg"] fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"] return fitting @@ -922,7 +928,11 @@ def serialize(self, suffix: str = "") -> dict: "nets": self.serialize_network( ntypes=self.ntypes, ndim=0 if self.mixed_types else 1, - in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam, + in_dim=( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + ), neuron=self.n_neuron, activation_function=self.activation_function_name, resnet_dt=self.resnet_dt, diff --git a/source/tests/consistent/fitting/common.py 
diff --git a/source/tests/consistent/fitting/common.py b/source/tests/consistent/fitting/common.py
index bdd4b7cf81..95557d9ab8 100644
--- a/source/tests/consistent/fitting/common.py
+++ b/source/tests/consistent/fitting/common.py
@@ -18,7 +18,7 @@
 class FittingTest:
     """Useful utilities for descriptor tests."""

-    def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
+    def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, aparam, suffix):
         t_inputs = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_inputs")
         t_natoms = tf.placeholder(tf.int32, natoms.shape, name="i_natoms")
         t_atype = tf.placeholder(tf.int32, [None], name="i_atype")
@@ -30,6 +30,12 @@ def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
             )
             extras["fparam"] = t_fparam
             feed_dict[t_fparam] = fparam
+        if aparam is not None:
+            t_aparam = tf.placeholder(
+                GLOBAL_TF_FLOAT_PRECISION, [None, None], name="i_aparam"
+            )
+            extras["aparam"] = t_aparam
+            feed_dict[t_aparam] = aparam
         t_out = obj.build(
             t_inputs,
             t_natoms,
diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py
index 4a78b69341..774e3f655e 100644
--- a/source/tests/consistent/fitting/test_dos.py
+++ b/source/tests/consistent/fitting/test_dos.py
@@ -58,6 +58,7 @@
     ("float64", "float32"),  # precision
     (True, False),  # mixed_types
     (0, 1),  # numb_fparam
+    (0, 1),  # numb_aparam
     (10, 20),  # numb_dos
 )
 class TestDOS(CommonTest, FittingTest, unittest.TestCase):
@@ -68,6 +69,7 @@ def data(self) -> dict:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         return {
@@ -75,6 +77,7 @@ def data(self) -> dict:
             "resnet_dt": resnet_dt,
             "precision": precision,
             "numb_fparam": numb_fparam,
+            "numb_aparam": numb_aparam,
             "seed": 20240217,
             "numb_dos": numb_dos,
         }
@@ -86,6 +89,7 @@ def skip_pt(self) -> bool:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         return CommonTest.skip_pt
@@ -115,6 +119,9 @@ def setUp(self):
         # inconsistent if not sorted
         self.atype.sort()
         self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)
+        self.aparam = np.zeros_like(
+            self.atype, dtype=GLOBAL_NP_FLOAT_PRECISION
+        ).reshape(-1, 1)

     @property
     def addtional_data(self) -> dict:
@@ -123,6 +130,7 @@ def addtional_data(self) -> dict:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         return {
@@ -137,6 +145,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         return self.build_tf_fitting(
@@ -145,6 +154,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
             self.natoms,
             self.atype,
             self.fparam if numb_fparam else None,
+            self.aparam if numb_aparam else None,
             suffix,
         )

@@ -154,6 +164,7 @@ def eval_pt(self, pt_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         return (
@@ -163,6 +174,9 @@ def eval_pt(self, pt_obj: Any) -> Any:
                 fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
                 if numb_fparam
                 else None,
+                aparam=torch.from_numpy(self.aparam).to(device=PT_DEVICE)
+                if numb_aparam
+                else None,
             )["dos"]
             .detach()
             .cpu()
@@ -175,12 +189,14 @@ def eval_dp(self, dp_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         return dp_obj(
             self.inputs,
             self.atype.reshape(1, -1),
             fparam=self.fparam if numb_fparam else None,
+            aparam=self.aparam if numb_aparam else None,
         )["dos"]

     def eval_jax(self, jax_obj: Any) -> Any:
@@ -189,6 +205,7 @@ def eval_jax(self, jax_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         return np.asarray(
@@ -196,6 +213,7 @@ def eval_jax(self, jax_obj: Any) -> Any:
                 jnp.asarray(self.inputs),
                 jnp.asarray(self.atype.reshape(1, -1)),
                 fparam=jnp.asarray(self.fparam) if numb_fparam else None,
+                aparam=jnp.asarray(self.aparam) if numb_aparam else None,
             )["dos"]
         )

@@ -206,6 +224,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         return np.asarray(
@@ -213,6 +232,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
             array_api_strict.asarray(self.inputs),
             array_api_strict.asarray(self.atype.reshape(1, -1)),
             fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
+            aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None,
             )["dos"]
         )

@@ -230,6 +250,7 @@ def rtol(self) -> float:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         if precision == "float64":
@@ -247,6 +268,7 @@ def atol(self) -> float:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             numb_dos,
         ) = self.param
         if precision == "float64":
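Editor's note: the consistency tests feed an all-zero aparam with one value per atom, as built in `setUp()` above. A quick sketch of the resulting shape (pure NumPy; `natoms` is illustrative):

```python
import numpy as np

natoms = 6
atype = np.zeros((natoms,), dtype=np.int64)  # stand-in for self.atype

# One atomic parameter per atom, as in setUp(): shape (natoms, numb_aparam).
aparam = np.zeros_like(atype, dtype=np.float64).reshape(-1, 1)
assert aparam.shape == (natoms, 1)
```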
diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py
index ba2be1d86b..e32410a0ec 100644
--- a/source/tests/consistent/fitting/test_ener.py
+++ b/source/tests/consistent/fitting/test_ener.py
@@ -60,6 +60,7 @@
     ("float64", "float32", "bfloat16"),  # precision
     (True, False),  # mixed_types
     (0, 1),  # numb_fparam
+    ((0, False), (1, False), (1, True)),  # (numb_aparam, use_aparam_as_mask)
     ([], [-12345.6, None]),  # atom_ener
 )
 class TestEner(CommonTest, FittingTest, unittest.TestCase):
@@ -70,6 +71,7 @@ def data(self) -> dict:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         return {
@@ -77,8 +79,10 @@ def data(self) -> dict:
             "resnet_dt": resnet_dt,
             "precision": precision,
             "numb_fparam": numb_fparam,
+            "numb_aparam": numb_aparam,
             "seed": 20240217,
             "atom_ener": atom_ener,
+            "use_aparam_as_mask": use_aparam_as_mask,
         }

     @property
@@ -88,6 +92,7 @@ def skip_pt(self) -> bool:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         return CommonTest.skip_pt
@@ -101,6 +106,7 @@ def skip_array_api_strict(self) -> bool:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         # TypeError: The array_api_strict namespace does not support the dtype 'bfloat16'
@@ -123,6 +129,9 @@ def setUp(self):
         # inconsistent if not sorted
         self.atype.sort()
         self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)
+        self.aparam = np.zeros_like(
+            self.atype, dtype=GLOBAL_NP_FLOAT_PRECISION
+        ).reshape(-1, 1)

     @property
     def addtional_data(self) -> dict:
@@ -131,6 +140,7 @@ def addtional_data(self) -> dict:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         return {
@@ -145,6 +155,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         return self.build_tf_fitting(
@@ -153,6 +164,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
             self.natoms,
             self.atype,
             self.fparam if numb_fparam else None,
+            self.aparam if numb_aparam else None,
             suffix,
         )

@@ -162,15 +174,23 @@ def eval_pt(self, pt_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         return (
             pt_obj(
                 torch.from_numpy(self.inputs).to(device=PT_DEVICE),
                 torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE),
-                fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
-                if numb_fparam
-                else None,
+                fparam=(
+                    torch.from_numpy(self.fparam).to(device=PT_DEVICE)
+                    if numb_fparam
+                    else None
+                ),
+                aparam=(
+                    torch.from_numpy(self.aparam).to(device=PT_DEVICE)
+                    if numb_aparam
+                    else None
+                ),
             )["energy"]
             .detach()
             .cpu()
@@ -183,12 +203,14 @@ def eval_dp(self, dp_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         return dp_obj(
             self.inputs,
             self.atype.reshape(1, -1),
             fparam=self.fparam if numb_fparam else None,
+            aparam=self.aparam if numb_aparam else None,
         )["energy"]

     def eval_jax(self, jax_obj: Any) -> Any:
@@ -197,6 +219,7 @@ def eval_jax(self, jax_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         return np.asarray(
@@ -204,6 +227,7 @@ def eval_jax(self, jax_obj: Any) -> Any:
                 jnp.asarray(self.inputs),
                 jnp.asarray(self.atype.reshape(1, -1)),
                 fparam=jnp.asarray(self.fparam) if numb_fparam else None,
+                aparam=jnp.asarray(self.aparam) if numb_aparam else None,
             )["energy"]
         )

@@ -214,6 +238,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         return np.asarray(
@@ -221,6 +246,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
                 array_api_strict.asarray(self.inputs),
                 array_api_strict.asarray(self.atype.reshape(1, -1)),
                 fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
+                aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None,
             )["energy"]
         )

@@ -238,6 +264,7 @@ def rtol(self) -> float:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         if precision == "float64":
@@ -257,6 +284,7 @@ def atol(self) -> float:
             precision,
             mixed_types,
             numb_fparam,
+            (numb_aparam, use_aparam_as_mask),
             atom_ener,
         ) = self.param
         if precision == "float64":
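Editor's note: test_ener parameterizes `(numb_aparam, use_aparam_as_mask)` as a joint tuple rather than a full cross product, which skips the meaningless combination of masking with no aparam. A sketch of the equivalent filter:

```python
import itertools

numb_aparam_options = (0, 1)
mask_options = (False, True)

valid = [
    (numb_aparam, mask)
    for numb_aparam, mask in itertools.product(numb_aparam_options, mask_options)
    if not (mask and numb_aparam == 0)  # masking requires an aparam to mask
]
assert valid == [(0, False), (1, False), (1, True)]
```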
diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py
index a9fb6b694a..beb21d9c04 100644
--- a/source/tests/consistent/fitting/test_property.py
+++ b/source/tests/consistent/fitting/test_property.py
@@ -40,6 +40,7 @@
     ("float64", "float32"),  # precision
     (True, False),  # mixed_types
     (0, 1),  # numb_fparam
+    (0, 1),  # numb_aparam
     (1, 3),  # task_dim
     (True, False),  # intensive
 )
@@ -51,6 +52,7 @@ def data(self) -> dict:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             task_dim,
             intensive,
         ) = self.param
@@ -59,6 +61,7 @@ def data(self) -> dict:
             "resnet_dt": resnet_dt,
             "precision": precision,
             "numb_fparam": numb_fparam,
+            "numb_aparam": numb_aparam,
             "seed": 20240217,
             "task_dim": task_dim,
             "intensive": intensive,
@@ -71,6 +74,7 @@ def skip_pt(self) -> bool:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             task_dim,
             intensive,
         ) = self.param
@@ -95,6 +99,9 @@ def setUp(self):
         # inconsistent if not sorted
         self.atype.sort()
         self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)
+        self.aparam = np.zeros_like(
+            self.atype, dtype=GLOBAL_NP_FLOAT_PRECISION
+        ).reshape(-1, 1)

     @property
     def addtional_data(self) -> dict:
@@ -103,6 +110,7 @@ def addtional_data(self) -> dict:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             task_dim,
             intensive,
         ) = self.param
@@ -118,6 +126,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             task_dim,
             intensive,
         ) = self.param
@@ -127,6 +136,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
             self.natoms,
             self.atype,
             self.fparam if numb_fparam else None,
+            self.aparam if numb_aparam else None,
             suffix,
         )

@@ -136,6 +146,7 @@ def eval_pt(self, pt_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             task_dim,
             intensive,
         ) = self.param
@@ -146,6 +157,9 @@ def eval_pt(self, pt_obj: Any) -> Any:
                 fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
                 if numb_fparam
                 else None,
+                aparam=torch.from_numpy(self.aparam).to(device=PT_DEVICE)
+                if numb_aparam
+                else None,
             )["property"]
             .detach()
             .cpu()
@@ -158,6 +172,7 @@ def eval_dp(self, dp_obj: Any) -> Any:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             task_dim,
             intensive,
         ) = self.param
@@ -165,6 +180,7 @@ def eval_dp(self, dp_obj: Any) -> Any:
             self.inputs,
             self.atype.reshape(1, -1),
             fparam=self.fparam if numb_fparam else None,
+            aparam=self.aparam if numb_aparam else None,
         )["property"]

     def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
@@ -181,6 +197,7 @@ def rtol(self) -> float:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             task_dim,
             intensive,
         ) = self.param
@@ -199,6 +216,7 @@ def atol(self) -> float:
             precision,
             mixed_types,
             numb_fparam,
+            numb_aparam,
             task_dim,
             intensive,
         ) = self.param
diff --git a/source/tests/pt/model/test_ener_fitting.py b/source/tests/pt/model/test_ener_fitting.py
index 5c55766455..acf0a47769 100644
--- a/source/tests/pt/model/test_ener_fitting.py
+++ b/source/tests/pt/model/test_ener_fitting.py
@@ -36,6 +36,7 @@ def setUp(self):
     def test_consistency(
         self,
     ):
+        # Guards against: ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 1600 is different from 1604)
         rng = np.random.default_rng(GLOBAL_SEED)
         nf, nloc, nnei = self.nlist.shape
         dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE)
@@ -46,13 +47,14 @@ def test_consistency(
         )
         atype = torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE)

-        for od, mixed_types, nfp, nap, et, nn in itertools.product(
+        for od, mixed_types, nfp, nap, et, nn, use_aparam_as_mask in itertools.product(
             [1, 3],
             [True, False],
             [0, 3],
             [0, 4],
             [[], [0], [1]],
             [[4, 4, 4], []],
+            [True, False],
         ):
             ft0 = InvarFitting(
                 "foo",
@@ -65,6 +67,7 @@ def test_consistency(
                 exclude_types=et,
                 neuron=nn,
                 seed=GLOBAL_SEED,
+                use_aparam_as_mask=use_aparam_as_mask,
             ).to(env.DEVICE)
             ft1 = DPInvarFitting.deserialize(ft0.serialize())
             ft2 = InvarFitting.deserialize(ft0.serialize())
@@ -105,12 +108,13 @@ def test_consistency(
     def test_jit(
         self,
     ):
-        for od, mixed_types, nfp, nap, et in itertools.product(
+        for od, mixed_types, nfp, nap, et, use_aparam_as_mask in itertools.product(
             [1, 3],
             [True, False],
             [0, 3],
             [0, 4],
             [[], [0]],
+            [True, False],
         ):
             ft0 = InvarFitting(
                 "foo",
@@ -122,6 +126,7 @@ def test_jit(
                 mixed_types=mixed_types,
                 exclude_types=et,
                 seed=GLOBAL_SEED,
+                use_aparam_as_mask=use_aparam_as_mask,
             ).to(env.DEVICE)
             torch.jit.script(ft0)

@@ -146,3 +151,38 @@ def test_get_set(self):
             np.testing.assert_allclose(
                 foo, np.reshape(ifn0[ii].detach().cpu().numpy(), foo.shape)
             )
+
+    def test_use_aparam_as_mask(self):
+        nap = 4
+        dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE)
+
+        for od, mixed_types, nfp, et, nn in itertools.product(
+            [1, 3],
+            [True, False],
+            [0, 3],
+            [[], [0], [1]],
+            [[4, 4, 4], []],
+        ):
+            ft0 = InvarFitting(
+                "foo",
+                self.nt,
+                dd0.dim_out,
+                od,
+                numb_fparam=nfp,
+                numb_aparam=nap,
+                mixed_types=mixed_types,
+                exclude_types=et,
+                neuron=nn,
+                seed=GLOBAL_SEED,
+                use_aparam_as_mask=True,
+            ).to(env.DEVICE)
+            in_dim = ft0.dim_descrpt + ft0.numb_fparam
+            assert ft0.filter_layers[0].in_dim == in_dim
+
+            ft1 = DPInvarFitting.deserialize(ft0.serialize())
+            in_dim = ft1.dim_descrpt + ft1.numb_fparam
+            assert ft1.nets[0].in_dim == in_dim
+
+            ft2 = InvarFitting.deserialize(ft0.serialize())
+            in_dim = ft2.dim_descrpt + ft2.numb_fparam
+            assert ft2.filter_layers[0].in_dim == in_dim
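Editor's note: for users, the switch is intended to be flipped from the fitting-net configuration. A minimal sketch written as a Python dict, assuming the key names follow the constructor arguments added in this patch; every other setting below is illustrative, not prescribed by the diff:

```python
# Sketch of a fitting_net section enabling aparam-as-mask; keys mirror the
# parameters touched by this patch, other values are placeholders.
fitting_net_config = {
    "type": "ener",
    "neuron": [240, 240, 240],
    "resnet_dt": True,
    "numb_fparam": 0,
    "numb_aparam": 1,            # one atomic parameter per atom...
    "use_aparam_as_mask": True,  # ...used as a mask, not as a network input
    "seed": 1,
}
```

With this configuration the fitting net's input width excludes `numb_aparam`, matching the assertions in `test_use_aparam_as_mask` above.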