Chore(pt): refactor the command function interface
iProzd committed Oct 16, 2024
1 parent 5050f61 commit 3e08cfd
Showing 1 changed file with 96 additions and 54 deletions.
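The commit replaces the single argparse `FLAGS` namespace that `train`, `freeze`, and `change_bias` used to receive with explicit, typed keyword arguments, so the entry points can be called from Python without building a command-line parser. A minimal sketch of a call against the new `train` signature, assuming the import path matches the file changed below and using placeholder file names:

    # Hypothetical programmatic call; previously this required an argparse Namespace.
    from deepmd.pt.entrypoints.main import train

    train(
        input_file="input.json",      # training configuration (placeholder name)
        init_model=None,              # no checkpoint to initialize from
        restart=None,
        finetune=None,
        init_frz_model=None,
        model_branch="",              # only relevant for multi-task fine-tuning
        skip_neighbor_stat=False,
        output="out.json",            # where the resolved config is written
    )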
deepmd/pt/entrypoints/main.py: 150 changes (96 additions, 54 deletions)
@@ -239,16 +239,27 @@ def get_backend_info(self) -> dict:
}


def train(FLAGS):
log.info("Configuration path: %s", FLAGS.INPUT)
def train(
input_file: str,
init_model: Optional[str],
restart: Optional[str],
finetune: Optional[str],
init_frz_model: Optional[str],
model_branch: str,
skip_neighbor_stat: bool = False,
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
):
log.info("Configuration path: %s", input_file)
SummaryPrinter()()
with open(FLAGS.INPUT) as fin:
with open(input_file) as fin:
config = json.load(fin)
# ensure suffix, as in the command line help, we say "path prefix of checkpoint files"
if FLAGS.init_model is not None and not FLAGS.init_model.endswith(".pt"):
FLAGS.init_model += ".pt"
if FLAGS.restart is not None and not FLAGS.restart.endswith(".pt"):
FLAGS.restart += ".pt"
if init_model is not None and not init_model.endswith(".pt"):
init_model += ".pt"
if restart is not None and not restart.endswith(".pt"):
restart += ".pt"

# update multitask config
multi_task = "model_dict" in config["model"]
@@ -262,26 +273,26 @@ def train(FLAGS):

# update fine-tuning config
finetune_links = None
if FLAGS.finetune is not None:
if finetune is not None:
config["model"], finetune_links = get_finetune_rules(
FLAGS.finetune,
finetune,
config["model"],
model_branch=FLAGS.model_branch,
change_model_params=FLAGS.use_pretrain_script,
model_branch=model_branch,
change_model_params=use_pretrain_script,
)
# update init_model or init_frz_model config if necessary
if (
FLAGS.init_model is not None or FLAGS.init_frz_model is not None
) and FLAGS.use_pretrain_script:
if FLAGS.init_model is not None:
init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE)
init_model is not None or init_frz_model is not None
) and use_pretrain_script:
if init_model is not None:
init_state_dict = torch.load(init_model, map_location=DEVICE)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]
else:
config["model"] = json.loads(
torch.jit.load(
FLAGS.init_frz_model, map_location=DEVICE
init_frz_model, map_location=DEVICE
).get_model_def_script()
)

@@ -291,7 +302,7 @@

# do neighbor stat
min_nbor_dist = None
if not FLAGS.skip_neighbor_stat:
if not skip_neighbor_stat:
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
@@ -320,16 +331,16 @@ def train(FLAGS):
)
)

with open(FLAGS.output, "w") as fp:
with open(output, "w") as fp:
json.dump(config, fp, indent=4)

trainer = get_trainer(
config,
FLAGS.init_model,
FLAGS.restart,
FLAGS.finetune,
FLAGS.force_load,
FLAGS.init_frz_model,
init_model,
restart,
finetune,
force_load,
init_frz_model,
shared_links=shared_links,
finetune_links=finetune_links,
)
@@ -343,26 +354,38 @@ def train(FLAGS):
trainer.run()


def freeze(FLAGS):
model = inference.Tester(FLAGS.model, head=FLAGS.head).model
def freeze(
model: str,
output: str = "frozen_model.pth",
head: Optional[str] = None,
):
model = inference.Tester(model, head=head).model
model.eval()
model = torch.jit.script(model)
extra_files = {}
torch.jit.save(
model,
FLAGS.output,
output,
extra_files,
)
log.info(f"Saved frozen model to {FLAGS.output}")


def change_bias(FLAGS):
if FLAGS.INPUT.endswith(".pt"):
old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
log.info(f"Saved frozen model to {output}")

def change_bias(
input_file: str,
mode: str = "change",
bias_value: Optional[list] = None,
datafile: Optional[str] = None,
system: str = ".",
numb_batch: int = 0,
model_branch: Optional[str] = None,
output: Optional[str] = None,
):
if input_file.endswith(".pt"):
old_state_dict = torch.load(input_file, map_location=env.DEVICE)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
model_params = model_state_dict["_extra_state"]["model_params"]
elif FLAGS.INPUT.endswith(".pth"):
old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
elif input_file.endswith(".pth"):
old_model = torch.jit.load(input_file, map_location=env.DEVICE)
model_params_string = old_model.get_model_def_script()
model_params = json.loads(model_params_string)
old_state_dict = old_model.state_dict()
@@ -373,9 +396,8 @@ def change_bias(FLAGS):
"or a frozen model with a .pth extension"
)
multi_task = "model_dict" in model_params
model_branch = FLAGS.model_branch
bias_adjust_mode = (
"change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic"
"change-by-statistic" if mode == "change" else "set-by-statistic"
)
if multi_task:
assert (
@@ -393,24 +415,24 @@
else model_params["model_dict"][model_branch]["type_map"]
)
model_to_change = model if not multi_task else model[model_branch]
if FLAGS.INPUT.endswith(".pt"):
if input_file.endswith(".pt"):
wrapper = ModelWrapper(model)
wrapper.load_state_dict(old_state_dict["model"])
else:
# for .pth
model.load_state_dict(old_state_dict)

if FLAGS.bias_value is not None:
if bias_value is not None:
# use user-defined bias
assert model_to_change.model_type in [
"ener"
], "User-defined bias is only available for energy model!"
assert (
len(FLAGS.bias_value) == len(type_map)
len(bias_value) == len(type_map)
), f"The number of elements in the bias should be the same as that in the type_map: {type_map}."
old_bias = model_to_change.get_out_bias()
bias_to_set = torch.tensor(
FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device
bias_value, dtype=old_bias.dtype, device=old_bias.device
).view(old_bias.shape)
model_to_change.set_out_bias(bias_to_set)
log.info(
@@ -421,11 +443,11 @@
updated_model = model_to_change
else:
# calculate bias on given systems
if FLAGS.datafile is not None:
with open(FLAGS.datafile) as datalist:
if datafile is not None:
with open(datafile) as datalist:
all_sys = datalist.read().splitlines()
else:
all_sys = expand_sys_str(FLAGS.system)
all_sys = expand_sys_str(system)
data_systems = process_systems(all_sys)
data_single = DpLoaderSet(
data_systems,
@@ -438,7 +460,7 @@
data_requirement = mock_loss.label_requirement
data_requirement += training.get_additional_data_requirement(model_to_change)
data_single.add_data_requirement(data_requirement)
nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf")
nbatches = numb_batch if numb_batch != 0 else float("inf")
sampled_data = make_stat_input(
data_single.systems,
data_single.dataloaders,
@@ -453,11 +475,11 @@
else:
model[model_branch] = updated_model

if FLAGS.INPUT.endswith(".pt"):
if input_file.endswith(".pt"):
output_path = (
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pt", "_updated.pt")
output
if output is not None
else input_file.replace(".pt", "_updated.pt")
)
wrapper = ModelWrapper(model)
if "model" in old_state_dict:
@@ -470,9 +492,9 @@
else:
# for .pth
output_path = (
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pth", "_updated.pth")
output
if output is not None
else input_file.replace(".pth", "_updated.pth")
)
model = torch.jit.script(model)
torch.jit.save(
@@ -499,7 +521,18 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
log.info("DeePMD version: %s", __version__)

if FLAGS.command == "train":
train(FLAGS)
train(
input_file=FLAGS.INPUT,
init_model=FLAGS.init_model,
restart=FLAGS.restart,
finetune=FLAGS.finetune,
init_frz_model=FLAGS.init_frz_model,
model_branch=FLAGS.model_branch,
skip_neighbor_stat=FLAGS.skip_neighbor_stat,
use_pretrain_script=FLAGS.use_pretrain_script,
force_load=FLAGS.force_load,
output=FLAGS.output,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
@@ -508,9 +541,18 @@
else:
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
freeze(FLAGS)
freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
elif FLAGS.command == "change-bias":
change_bias(FLAGS)
change_bias(
input_file=FLAGS.INPUT,
mode=FLAGS.mode,
bias_value=FLAGS.bias_value,
datafile=FLAGS.datafile,
system=FLAGS.system,
numb_batch=FLAGS.numb_batch,
model_branch=FLAGS.model_branch,
output=FLAGS.output,
)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")
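With the same refactor, `freeze` and `change_bias` can also be driven programmatically; `main()` above now simply forwards the parsed flags into these keyword arguments. A hedged sketch under that assumption (checkpoint and system paths below are placeholders, not taken from this commit):

    from deepmd.pt.entrypoints.main import change_bias, freeze

    # Export a checkpoint to a TorchScript .pth model (placeholder file names).
    freeze(model="model.ckpt.pt", output="frozen_model.pth", head=None)

    # Recompute the output bias of a checkpoint from data in a hypothetical system directory.
    change_bias(
        input_file="model.ckpt.pt",
        mode="change",
        system="./training_data",
        numb_batch=10,
    )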

