diff --git a/deepmd/main.py b/deepmd/main.py index 4560df9e57..6baf91adca 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -659,6 +659,72 @@ def main_parser() -> argparse.ArgumentParser: help="treat all types as a single type. Used with se_atten descriptor.", ) + # change_bias + parser_change_bias = subparsers.add_parser( + "change-bias", + parents=[parser_log], + help="(Supported backend: PyTorch) Change model out bias according to the input data.", + formatter_class=RawTextArgumentDefaultsHelpFormatter, + epilog=textwrap.dedent( + """\ + examples: + dp change-bias model.pt -s data -n 10 -m change + """ + ), + ) + parser_change_bias.add_argument( + "INPUT", help="The input checkpoint file or frozen model file" + ) + parser_change_bias_source = parser_change_bias.add_mutually_exclusive_group() + parser_change_bias_source.add_argument( + "-s", + "--system", + default=".", + type=str, + help="The system dir. Recursively detect systems in this directory", + ) + parser_change_bias_source.add_argument( + "-b", + "--bias-value", + default=None, + type=float, + nargs="+", + help="The user defined value for each type in the type_map of the model, split with spaces.\n" + "For example, '-93.57 -187.1' for energy bias of two elements. " + "Only supports energy bias changing.", + ) + parser_change_bias.add_argument( + "-n", + "--numb-batch", + default=0, + type=int, + help="The number of frames for bias changing in one data system. 0 means all data.", + ) + parser_change_bias.add_argument( + "-m", + "--mode", + type=str, + default="change", + choices=["change", "set"], + help="The mode for changing energy bias: \n" + "change (default) : perform predictions using input model on target dataset, " + "and do least square on the errors to obtain the target shift as bias.\n" + "set : directly use the statistic bias in the target dataset.", + ) + parser_change_bias.add_argument( + "-o", + "--output", + default=None, + type=str, + help="The model after changing bias.", + ) + parser_change_bias.add_argument( + "--model-branch", + type=str, + default=None, + help="Model branch chosen for changing bias if multi-task model.", + ) + # --version parser.add_argument( "--version", action="version", version=f"DeePMD-kit v{__version__}" @@ -831,6 +897,7 @@ def main(): "convert-from", "train-nvnmd", "show", + "change-bias", ): deepmd_main = BACKENDS[args.backend]().entry_point_hook elif args.command is None: diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index c28fa02e70..f5e7db8aa8 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import argparse +import copy import json import logging import os @@ -23,6 +24,9 @@ from deepmd import ( __version__, ) +from deepmd.common import ( + expand_sys_str, +) from deepmd.env import ( GLOBAL_CONFIG, ) @@ -44,6 +48,9 @@ from deepmd.pt.train import ( training, ) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) from deepmd.pt.utils import ( env, ) @@ -59,6 +66,12 @@ from deepmd.pt.utils.multi_task import ( preprocess_shared_params, ) +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) from deepmd.utils.argcheck import ( normalize, ) @@ -377,6 +390,128 @@ def show(FLAGS): log.info(f"The fitting_net parameter is {fitting_net}") +def change_bias(FLAGS): + if FLAGS.INPUT.endswith(".pt"): + old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE) + model_state_dict = 
copy.deepcopy(old_state_dict.get("model", old_state_dict)) + model_params = model_state_dict["_extra_state"]["model_params"] + elif FLAGS.INPUT.endswith(".pth"): + old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE) + model_params_string = old_model.get_model_def_script() + model_params = json.loads(model_params_string) + old_state_dict = old_model.state_dict() + model_state_dict = old_state_dict + else: + raise RuntimeError( + "The model provided must be a checkpoint file with a .pt extension " + "or a frozen model with a .pth extension" + ) + multi_task = "model_dict" in model_params + model_branch = FLAGS.model_branch + bias_adjust_mode = ( + "change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic" + ) + if multi_task: + assert ( + model_branch is not None + ), "For multitask model, the model branch must be set!" + assert model_branch in model_params["model_dict"], ( + f"For multitask model, the model branch must be in the 'model_dict'! " + f"Available options are : {list(model_params['model_dict'].keys())}." + ) + log.info(f"Changing out bias for model {model_branch}.") + model = training.get_model_for_wrapper(model_params) + type_map = ( + model_params["type_map"] + if not multi_task + else model_params["model_dict"][model_branch]["type_map"] + ) + model_to_change = model if not multi_task else model[model_branch] + if FLAGS.INPUT.endswith(".pt"): + wrapper = ModelWrapper(model) + wrapper.load_state_dict(old_state_dict["model"]) + else: + # for .pth + model.load_state_dict(old_state_dict) + + if FLAGS.bias_value is not None: + # use user-defined bias + assert model_to_change.model_type in [ + "ener" + ], "User-defined bias is only available for energy model!" + assert ( + len(FLAGS.bias_value) == len(type_map) + ), f"The number of elements in the bias should be the same as that in the type_map: {type_map}." + old_bias = model_to_change.get_out_bias() + bias_to_set = torch.tensor( + FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device + ).view(old_bias.shape) + model_to_change.set_out_bias(bias_to_set) + log.info( + f"Change output bias of {type_map!s} " + f"from {to_numpy_array(old_bias).reshape(-1)!s} " + f"to {to_numpy_array(bias_to_set).reshape(-1)!s}." 
+ ) + updated_model = model_to_change + else: + # calculate bias on given systems + data_systems = process_systems(expand_sys_str(FLAGS.system)) + data_single = DpLoaderSet( + data_systems, + 1, + type_map, + ) + mock_loss = training.get_loss( + {"inference": True}, 1.0, len(type_map), model_to_change + ) + data_requirement = mock_loss.label_requirement + data_requirement += training.get_additional_data_requirement(model_to_change) + data_single.add_data_requirement(data_requirement) + nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf") + sampled_data = make_stat_input( + data_single.systems, + data_single.dataloaders, + nbatches, + ) + updated_model = training.model_change_out_bias( + model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode + ) + + if not multi_task: + model = updated_model + else: + model[model_branch] = updated_model + + if FLAGS.INPUT.endswith(".pt"): + output_path = ( + FLAGS.output + if FLAGS.output is not None + else FLAGS.INPUT.replace(".pt", "_updated.pt") + ) + wrapper = ModelWrapper(model) + if "model" in old_state_dict: + old_state_dict["model"] = wrapper.state_dict() + old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"] + else: + old_state_dict = wrapper.state_dict() + old_state_dict["_extra_state"] = model_state_dict["_extra_state"] + torch.save(old_state_dict, output_path) + else: + # for .pth + output_path = ( + FLAGS.output + if FLAGS.output is not None + else FLAGS.INPUT.replace(".pth", "_updated.pth") + ) + model = torch.jit.script(model) + torch.jit.save( + model, + output_path, + {}, + ) + log.info(f"Saved model to {output_path}") + + @record def main(args: Optional[Union[List[str], argparse.Namespace]] = None): if not isinstance(args, argparse.Namespace): @@ -401,6 +536,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): freeze(FLAGS) elif FLAGS.command == "show": show(FLAGS) + elif FLAGS.command == "change-bias": + change_bias(FLAGS) else: raise RuntimeError(f"Invalid command {FLAGS.command}!") diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 97e329935a..092fbc1f76 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -96,7 +96,7 @@ def __init__( self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference - self.has_gf = (start_pref_gf != 0.0 and limit_pref_gf != 0.0) or inference + self.has_gf = start_pref_gf != 0.0 and limit_pref_gf != 0.0 self.start_pref_e = start_pref_e self.limit_pref_e = limit_pref_e diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 72f95c6d49..6a42393310 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -103,6 +103,9 @@ def init_out_stat(self): self.register_buffer("out_bias", out_bias_data) self.register_buffer("out_std", out_std_data) + def set_out_bias(self, out_bias: torch.Tensor) -> None: + self.out_bias = out_bias + def __setitem__(self, key, value): if key in ["out_bias"]: self.out_bias = value diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 38fa0e2530..32432725d3 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -175,6 +175,9 @@ def forward_common( def get_out_bias(self) -> torch.Tensor: return self.atomic_model.get_out_bias() 
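+    # Setter used by tools such as `dp change-bias` to overwrite the stored output bias directly.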
+ def set_out_bias(self, out_bias: torch.Tensor) -> None: + self.atomic_model.set_out_bias(out_bias) + def change_out_bias( self, merged, diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 9d5c9ea51e..e097d2a8b2 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -141,6 +141,9 @@ def __init__( self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5) self.display_in_training = training_params.get("disp_training", True) self.timing_in_training = training_params.get("time_training", True) + self.change_bias_after_training = training_params.get( + "change_bias_after_training", False + ) self.lcurve_should_print_header = True def get_opt_param(params): @@ -220,28 +223,7 @@ def single_model_stat( _data_requirement, finetune_has_new_type=False, ): - if _model.get_dim_fparam() > 0: - fparam_requirement_items = [ - DataRequirementItem( - "fparam", _model.get_dim_fparam(), atomic=False, must=True - ) - ] - _data_requirement += fparam_requirement_items - if _model.get_dim_aparam() > 0: - aparam_requirement_items = [ - DataRequirementItem( - "aparam", _model.get_dim_aparam(), atomic=True, must=True - ) - ] - _data_requirement += aparam_requirement_items - has_spin = getattr(_model, "has_spin", False) - if callable(has_spin): - has_spin = has_spin() - if has_spin: - spin_requirement_items = [ - DataRequirementItem("spin", ndof=3, atomic=True, must=True) - ] - _data_requirement += spin_requirement_items + _data_requirement += get_additional_data_requirement(_model) _training_data.add_data_requirement(_data_requirement) if _validation_data is not None: _validation_data.add_data_requirement(_data_requirement) @@ -264,15 +246,6 @@ def get_sample(): _stat_file_path.root.close() return get_sample - def get_single_model( - _model_params, - ): - if "use_srtab" in _model_params: - model = get_zbl_model(deepcopy(_model_params)).to(DEVICE) - else: - model = get_model(deepcopy(_model_params)).to(DEVICE) - return model - def get_lr(lr_params): assert ( lr_params.get("type", "exp") == "exp" @@ -281,39 +254,6 @@ def get_lr(lr_params): lr_exp = LearningRateExp(**lr_params) return lr_exp - def get_loss(loss_params, start_lr, _ntypes, _model): - loss_type = loss_params.get("type", "ener") - if loss_type == "ener": - loss_params["starter_learning_rate"] = start_lr - return EnergyStdLoss(**loss_params) - elif loss_type == "dos": - loss_params["starter_learning_rate"] = start_lr - loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size - return DOSLoss(**loss_params) - elif loss_type == "ener_spin": - loss_params["starter_learning_rate"] = start_lr - return EnergySpinLoss(**loss_params) - elif loss_type == "denoise": - loss_params["ntypes"] = _ntypes - return DenoiseLoss(**loss_params) - elif loss_type == "tensor": - model_output_type = _model.model_output_type() - if "mask" in model_output_type: - model_output_type.pop(model_output_type.index("mask")) - tensor_name = model_output_type[0] - loss_params["tensor_name"] = tensor_name - loss_params["tensor_size"] = _model.model_output_def()[ - tensor_name - ].output_size - label_name = tensor_name - if label_name == "polarizability": - label_name = "polar" - loss_params["label_name"] = label_name - loss_params["tensor_name"] = label_name - return TensorLoss(**loss_params) - else: - raise NotImplementedError - # Optimizer if self.multi_task and training_params.get("optim_dict", None) is not None: self.optim_dict = training_params.get("optim_dict") @@ -337,20 +277,6 @@ def get_loss(loss_params, start_lr, 
_ntypes, _model): if training_params["seed"] is not None: torch.manual_seed(training_params["seed"]) - def get_model_for_wrapper(_model_params): - if "model_dict" not in _model_params: - _model = get_single_model( - _model_params, - ) - else: - _model = {} - model_keys = list(_model_params["model_dict"]) - for _model_key in model_keys: - _model[_model_key] = get_single_model( - _model_params["model_dict"][_model_key], - ) - return _model - self.model = get_model_for_wrapper(model_params) # Loss @@ -600,7 +526,7 @@ def single_model_finetune( _finetune_rule_single, _sample_func, ): - _model = _model_change_out_bias( + _model = model_change_out_bias( _model, _sample_func, _bias_adjust_mode="change-by-statistic" @@ -1019,6 +945,28 @@ def log_loss_valid(_task_key="Default"): if JIT: break + if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0): + if not self.multi_task: + self.model = model_change_out_bias( + self.model, + self.get_sample_func, + _bias_adjust_mode="change-by-statistic", + ) + else: + for model_key in self.model_keys: + self.model[model_key] = model_change_out_bias( + self.model[model_key], + self.get_sample_func[model_key], + _bias_adjust_mode="change-by-statistic", + ) + self.latest_model = Path(self.save_ckpt + f"-{self.num_steps}.pt") + cur_lr = self.lr_exp.value(self.num_steps - 1) + self.save_model(self.latest_model, lr=cur_lr, step=self.num_steps - 1) + log.info(f"Saved model to {self.latest_model}") + symlink_prefix_files(self.latest_model.stem, self.save_ckpt) + with open("checkpoint", "w") as f: + f.write(str(self.latest_model)) + if ( self.rank == 0 or dist.get_rank() == 0 ): # Handle the case if rank 0 aborted and re-assigned @@ -1234,17 +1182,101 @@ def print_on_training(self, fout, step_id, cur_lr, train_results, valid_results) fout.flush() -def _model_change_out_bias( +def get_additional_data_requirement(_model): + additional_data_requirement = [] + if _model.get_dim_fparam() > 0: + fparam_requirement_items = [ + DataRequirementItem( + "fparam", _model.get_dim_fparam(), atomic=False, must=True + ) + ] + additional_data_requirement += fparam_requirement_items + if _model.get_dim_aparam() > 0: + aparam_requirement_items = [ + DataRequirementItem( + "aparam", _model.get_dim_aparam(), atomic=True, must=True + ) + ] + additional_data_requirement += aparam_requirement_items + has_spin = getattr(_model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + if has_spin: + spin_requirement_items = [ + DataRequirementItem("spin", ndof=3, atomic=True, must=True) + ] + additional_data_requirement += spin_requirement_items + return additional_data_requirement + + +def get_loss(loss_params, start_lr, _ntypes, _model): + loss_type = loss_params.get("type", "ener") + if loss_type == "ener": + loss_params["starter_learning_rate"] = start_lr + return EnergyStdLoss(**loss_params) + elif loss_type == "dos": + loss_params["starter_learning_rate"] = start_lr + loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size + return DOSLoss(**loss_params) + elif loss_type == "ener_spin": + loss_params["starter_learning_rate"] = start_lr + return EnergySpinLoss(**loss_params) + elif loss_type == "denoise": + loss_params["ntypes"] = _ntypes + return DenoiseLoss(**loss_params) + elif loss_type == "tensor": + model_output_type = _model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + tensor_name = model_output_type[0] + loss_params["tensor_name"] = tensor_name + 
loss_params["tensor_size"] = _model.model_output_def()[tensor_name].output_size + label_name = tensor_name + if label_name == "polarizability": + label_name = "polar" + loss_params["label_name"] = label_name + loss_params["tensor_name"] = label_name + return TensorLoss(**loss_params) + else: + raise NotImplementedError + + +def get_single_model( + _model_params, +): + if "use_srtab" in _model_params: + model = get_zbl_model(deepcopy(_model_params)).to(DEVICE) + else: + model = get_model(deepcopy(_model_params)).to(DEVICE) + return model + + +def get_model_for_wrapper(_model_params): + if "model_dict" not in _model_params: + _model = get_single_model( + _model_params, + ) + else: + _model = {} + model_keys = list(_model_params["model_dict"]) + for _model_key in model_keys: + _model[_model_key] = get_single_model( + _model_params["model_dict"][_model_key], + ) + return _model + + +def model_change_out_bias( _model, _sample_func, _bias_adjust_mode="change-by-statistic", ): - old_bias = _model.get_out_bias() + old_bias = deepcopy(_model.get_out_bias()) _model.change_out_bias( _sample_func, bias_adjust_mode=_bias_adjust_mode, ) - new_bias = _model.get_out_bias() + new_bias = deepcopy(_model.get_out_bias()) model_type_map = _model.get_type_map() log.info( diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index b0bbda5dbe..3ba159b7ce 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -53,7 +53,8 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat = {} with torch.device("cpu"): iterator = iter(dataloaders[i]) - for _ in range(nbatches): + numb_batches = min(nbatches, len(dataloaders[i])) + for _ in range(numb_batches): try: stat_data = next(iterator) except StopIteration: diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 60a468be3e..474af1da90 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -145,6 +145,9 @@ def get_lr_and_coef(lr_param): self.tensorboard_log_dir = tr_data.get("tensorboard_log_dir", "log") self.tensorboard_freq = tr_data.get("tensorboard_freq", 1) self.mixed_prec = tr_data.get("mixed_precision", None) + self.change_bias_after_training = tr_data.get( + "change_bias_after_training", False + ) if self.mixed_prec is not None: if ( self.mixed_prec["compute_prec"] not in ("float16", "bfloat16") @@ -563,6 +566,32 @@ def train(self, train_data=None, valid_data=None): and self.saver is not None ): self.save_checkpoint(cur_batch) + if self.change_bias_after_training: + import tempfile + + from deepmd.tf.entrypoints import ( + freeze, + ) + + self.save_checkpoint(cur_batch) + with tempfile.NamedTemporaryFile(suffix=".pb") as f: + freeze( + checkpoint_folder=os.path.join(os.getcwd(), self.save_ckpt), + output=f.name, + ) + self._change_energy_bias( + train_data, + f.name, + self.type_map, + bias_adjust_mode="change-by-statistic", + ) + assign_op = tf.assign( + self.model.get_fitting().t_bias_atom_e, + self.model.get_fitting().bias_atom_e, + ) + run_sess(self.sess, assign_op) + self.save_checkpoint(cur_batch) + if ( self.save_freq == 0 or cur_batch == 0 or cur_batch % self.save_freq != 0 ) and self.saver is not None: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index d34726e7b1..b10244d436 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2337,6 +2337,11 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. "The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. " "Defaults to 5." 
) + doc_change_bias_after_training = ( + "Whether to change the output bias after the last training step, " + "by performing predictions using trained model on training data and " + "doing least square on the errors to add the target shift on the bias." + ) doc_disp_training = "Displaying verbose information during training." doc_time_training = "Timing durining training." doc_profiling = "Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to `profiling_file`." @@ -2386,6 +2391,13 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. "save_ckpt", str, optional=True, default="model.ckpt", doc=doc_save_ckpt ), Argument("max_ckpt_keep", int, optional=True, default=5, doc=doc_max_ckpt_keep), + Argument( + "change_bias_after_training", + bool, + optional=True, + default=False, + doc=doc_change_bias_after_training, + ), Argument( "disp_training", bool, optional=True, default=True, doc=doc_disp_training ), diff --git a/doc/model/change-bias.md b/doc/model/change-bias.md new file mode 100644 index 0000000000..ac28201cb6 --- /dev/null +++ b/doc/model/change-bias.md @@ -0,0 +1,42 @@ +# Change the model output bias for trained model {{ pytorch_icon }} + +:::{note} +**Supported backends**: PyTorch {{ pytorch_icon }} +::: + +The output bias of a trained model typically originates from the statistical results of the training dataset. + +There are several scenarios where one might want to adjust the output bias after the model is trained, +such as zero-shot testing (similar to the procedure before the first step in fine-tuning) +or manually setting the output bias. + +The `dp --pt change-bias` command supports the following methods for adjusting the bias: + +::::{tab-set} + +:::{tab-item} Changing bias using provided systems for trained `.pt`/`.pth` models: + +```sh +dp --pt change-bias model.pt -s data_dir -o model_updated.pt +``` + +For multitask models, where `--model-branch` must be specified: + +```sh +dp --pt change-bias multi_model.pt -s data_dir -o model_updated.pt --model-branch model_1 +``` + +::: + +:::{tab-item} Changing bias using user input for **energy model**: + +```sh +dp --pt change-bias model.pt -b -92.523 -187.66 -o model_updated.pt +``` + +Here, `-b` specifies user-defined energy bias for each type, separated by space, +in an order consistent with the `type_map` in the model. 
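+
+To verify that the new bias was actually written into the checkpoint, one can reload the updated
+model and inspect its output bias. Below is a minimal sketch using internal helpers of the PyTorch
+backend; the checkpoint name `model_updated.pt` and the single-task `Default` model key are
+assumptions for illustration:
+
+```python
+import torch
+
+from deepmd.pt.train.training import get_model_for_wrapper
+from deepmd.pt.train.wrapper import ModelWrapper
+from deepmd.pt.utils.env import DEVICE
+
+# load the checkpoint written by `dp --pt change-bias`
+state_dict = torch.load("model_updated.pt", map_location=DEVICE)
+model_params = state_dict["model"]["_extra_state"]["model_params"]
+
+# rebuild the model, restore the weights, and read back the per-type output bias
+wrapper = ModelWrapper(get_model_for_wrapper(model_params))
+wrapper.load_state_dict(state_dict["model"])
+print(wrapper.model["Default"].get_out_bias())
+```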
+ +::: + +:::: diff --git a/doc/model/index.rst b/doc/model/index.rst index 7b7fb082f1..ff4e986178 100644 --- a/doc/model/index.rst +++ b/doc/model/index.rst @@ -22,3 +22,4 @@ Model dprc linear pairtab + change-bias diff --git a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py new file mode 100644 index 0000000000..67a08730ea --- /dev/null +++ b/source/tests/pt/test_change_bias.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import shutil +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np +import torch + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.train.training import ( + get_model_for_wrapper, + model_change_out_bias, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.env import ( + DEVICE, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) + +from .model.test_permutation import ( + model_se_e2_a, +) +from .test_finetune import ( + energy_data_requirement, +) + +current_path = os.getcwd() + + +class TestChangeBias(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + model_name = "change-bias-model.ckpt" + self.data_file = [str(Path(__file__).parent / "water/data/single")] + self.config["training"]["training_data"]["systems"] = self.data_file + self.config["training"]["validation_data"]["systems"] = self.data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["training"]["save_ckpt"] = model_name + self.trainer = get_trainer(deepcopy(self.config)) + self.trainer.run() + self.state_dict_trained = self.trainer.wrapper.model.state_dict() + data = DpLoaderSet( + self.data_file, + batch_size=1, + type_map=self.config["model"]["type_map"], + ) + data.add_data_requirement(energy_data_requirement) + self.sampled = make_stat_input( + data.systems, + data.dataloaders, + nbatches=1, + ) + self.model_path = Path(current_path) / (model_name + ".pt") + self.model_path_data_bias = Path(current_path) / ( + model_name + "data_bias" + ".pt" + ) + self.model_path_user_bias = Path(current_path) / ( + model_name + "user_bias" + ".pt" + ) + + def test_change_bias_with_data(self): + os.system( + f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}" + ) + state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE) + model_params = state_dict["model"]["_extra_state"]["model_params"] + model_for_wrapper = get_model_for_wrapper(model_params) + wrapper = ModelWrapper(model_for_wrapper) + wrapper.load_state_dict(state_dict["model"]) + updated_bias = wrapper.model["Default"].get_out_bias() + expected_model = model_change_out_bias( + self.trainer.wrapper.model["Default"], + self.sampled, + _bias_adjust_mode="change-by-statistic", + ) + expected_bias = expected_model.get_out_bias() + torch.testing.assert_close(updated_bias, expected_bias) + + def test_change_bias_with_user_defined(self): + user_bias = [0.1, 3.2, -0.5] + os.system( + f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}" + ) + state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE) + 
model_params = state_dict["model"]["_extra_state"]["model_params"] + model_for_wrapper = get_model_for_wrapper(model_params) + wrapper = ModelWrapper(model_for_wrapper) + wrapper.load_state_dict(state_dict["model"]) + updated_bias = wrapper.model["Default"].get_out_bias() + expected_bias = to_torch_tensor(np.array(user_bias)).view(updated_bias.shape) + torch.testing.assert_close(updated_bias, expected_bias) + + def tearDown(self): + for f in os.listdir("."): + if f.startswith("change-bias-model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f)