Skip to content

Commit

Permalink
feat(pt/tf): init-(frz)-model use pretrain script (#3926)
Browse files Browse the repository at this point in the history
Support `--use-pretrain-script` in both the PyTorch (pt) and TensorFlow (tf) backends when training is started with `--init-model` or `--init-frz-model`.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Tests**
- Enhanced and added new test cases for deep learning model
initialization and evaluation.
- Improved setup and cleanup processes for temporary files and
directories in tests to ensure a cleaner test environment.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored Jul 3, 2024
1 parent 73312f2 commit 1c3e099
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 8 deletions.
4 changes: 3 additions & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ def main_parser() -> argparse.ArgumentParser:
parser_train.add_argument(
"--use-pretrain-script",
action="store_true",
help="Use model parameters from the script of the pretrained model instead of user input when doing finetuning. Note: This behavior is default and unchangeable in TensorFlow.",
help="When performing fine-tuning or init-model, "
"utilize the model parameters provided by the script of the pretrained model rather than relying on user input. "
"It is important to note that in TensorFlow, this behavior is the default and cannot be modified for fine-tuning. ",
)
parser_train.add_argument(
"-o",
Expand Down
15 changes: 15 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,21 @@ def train(FLAGS):
model_branch=FLAGS.model_branch,
change_model_params=FLAGS.use_pretrain_script,
)
# update init_model or init_frz_model config if necessary
if (
FLAGS.init_model is not None or FLAGS.init_frz_model is not None
) and FLAGS.use_pretrain_script:
if FLAGS.init_model is not None:
init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]
else:
config["model"] = json.loads(
torch.jit.load(
FLAGS.init_frz_model, map_location=DEVICE
).get_model_def_script()
)

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
Expand Down
39 changes: 39 additions & 0 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def train(
is_compress: bool = False,
skip_neighbor_stat: bool = False,
finetune: Optional[str] = None,
use_pretrain_script: bool = False,
**kwargs,
):
"""Run DeePMD model training.
Expand Down Expand Up @@ -93,6 +94,9 @@ def train(
skip checking neighbor statistics
finetune : Optional[str]
path to pretrained model or None
use_pretrain_script : bool
Whether to use model script in pretrained model when doing init-model or init-frz-model.
Note that this option is true and unchangeable for fine-tuning.
**kwargs
additional arguments
Expand Down Expand Up @@ -123,6 +127,41 @@ def train(
jdata, run_opt.finetune
)

if (
run_opt.init_model is not None or run_opt.init_frz_model is not None
) and use_pretrain_script:
from deepmd.tf.utils.errors import (
GraphWithoutTensorError,
)
from deepmd.tf.utils.graph import (
get_tensor_by_name,
get_tensor_by_name_from_graph,
)

err_msg = (
f"The input model: {run_opt.init_model if run_opt.init_model is not None else run_opt.init_frz_model} has no training script, "
f"Please use the model pretrained with v2.1.5 or higher version of DeePMD-kit."
)
if run_opt.init_model is not None:
with tf.Graph().as_default() as graph:
tf.train.import_meta_graph(
f"{run_opt.init_model}.meta", clear_devices=True
)
try:
t_training_script = get_tensor_by_name_from_graph(
graph, "train_attr/training_script"
)
except GraphWithoutTensorError as e:
raise RuntimeError(err_msg) from e
else:
try:
t_training_script = get_tensor_by_name(
run_opt.init_frz_model, "train_attr/training_script"
)
except GraphWithoutTensorError as e:
raise RuntimeError(err_msg) from e
jdata["model"] = json.loads(t_training_script)["model"]

jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

jdata = normalize(jdata)
Expand Down
56 changes: 49 additions & 7 deletions source/tests/pt/test_init_frz_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import shutil
import tempfile
import unittest
from argparse import (
Namespace,
Expand All @@ -21,12 +24,17 @@
DeepPot,
)

from .common import (
run_dp,
)


class TestInitFrzModel(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
config = json.load(f)
config["model"]["descriptor"]["smooth_type_embedding"] = True
config["training"]["numb_steps"] = 1
config["training"]["save_freq"] = 1
config["learning_rate"]["start_lr"] = 1.0
Expand All @@ -38,15 +46,30 @@ def setUp(self):
]

self.models = []
for imodel in range(2):
if imodel == 1:
config["training"]["numb_steps"] = 0
trainer = get_trainer(deepcopy(config), init_frz_model=self.models[-1])
for imodel in range(3):
frozen_model = f"frozen_model{imodel}.pth"
if imodel == 0:
temp_config = deepcopy(config)
trainer = get_trainer(temp_config)
elif imodel == 1:
temp_config = deepcopy(config)
temp_config["training"]["numb_steps"] = 0
trainer = get_trainer(temp_config, init_frz_model=self.models[-1])
else:
trainer = get_trainer(deepcopy(config))
trainer.run()
empty_config = deepcopy(config)
empty_config["model"]["descriptor"] = {}
empty_config["model"]["fitting_net"] = {}
empty_config["training"]["numb_steps"] = 0
tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
with open(tmp_input.name, "w") as f:
json.dump(empty_config, f, indent=4)
run_dp(
f"dp --pt train {tmp_input.name} --init-frz-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat"
)
trainer = None

frozen_model = f"frozen_model{imodel}.pth"
if imodel in [0, 1]:
trainer.run()
ns = Namespace(
model="model.pt",
output=frozen_model,
Expand All @@ -58,6 +81,7 @@ def setUp(self):
def test_dp_test(self):
dp1 = DeepPot(str(self.models[0]))
dp2 = DeepPot(str(self.models[1]))
dp3 = DeepPot(str(self.models[2]))
cell = np.array(
[
5.122106549439247480e00,
Expand Down Expand Up @@ -96,8 +120,26 @@ def test_dp_test(self):
e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4]
ret2 = dp2.eval(coord, cell, atype, atomic=True)
e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4]
ret3 = dp3.eval(coord, cell, atype, atomic=True)
e3, f3, v3, ae3, av3 = ret3[0], ret3[1], ret3[2], ret3[3], ret3[4]
np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(e1, e3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(f1, f3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(v1, v3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(ae1, ae3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(av1, av3, rtol=1e-10, atol=1e-10)

def tearDown(self):
for f in os.listdir("."):
if f.startswith("frozen_model") and f.endswith(".pth"):
os.remove(f)
if f.startswith("model") and f.endswith(".pt"):
os.remove(f)
if f in ["lcurve.out"]:
os.remove(f)
if f in ["stat_files"]:
shutil.rmtree(f)
136 changes: 136 additions & 0 deletions source/tests/pt/test_init_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
)
from pathlib import (
Path,
)

import numpy as np

from deepmd.pt.entrypoints.main import (
get_trainer,
)
from deepmd.pt.infer.deep_eval import (
DeepPot,
)

from .common import (
run_dp,
)


class TestInitModel(unittest.TestCase):
    """Check that models restarted with ``init_model`` reproduce the original.

    Three checkpoints are produced in ``setUp`` and compared in
    ``test_dp_test``:

    * model 0 -- trained for one step from scratch;
    * model 1 -- initialised from model 0 through ``get_trainer`` and trained
      for zero steps;
    * model 2 -- initialised from model 0 through the ``dp --pt train`` command
      line with ``--use-pretrain-script``, starting from an input whose
      ``descriptor`` and ``fitting_net`` sections are empty.
    """

    def setUp(self):
        """Create the three checkpoints and record their paths in ``self.models``."""
        input_json = str(Path(__file__).parent / "water/se_atten.json")
        with open(input_json) as f:
            config = json.load(f)
        config["model"]["descriptor"]["smooth_type_embedding"] = True
        config["training"]["numb_steps"] = 1
        config["training"]["save_freq"] = 1
        config["learning_rate"]["start_lr"] = 1.0
        data_systems = [str(Path(__file__).parent / "water/data/single")]
        config["training"]["training_data"]["systems"] = list(data_systems)
        config["training"]["validation_data"]["systems"] = list(data_systems)

        self.models = []
        for imodel in range(3):
            ckpt_model = f"model{imodel}.ckpt"
            if imodel == 0:
                # Reference model: one real training step from scratch.
                temp_config = deepcopy(config)
                temp_config["training"]["save_ckpt"] = ckpt_model
                get_trainer(temp_config).run()
            elif imodel == 1:
                # Restart from the reference checkpoint without training.
                temp_config = deepcopy(config)
                temp_config["training"]["numb_steps"] = 0
                temp_config["training"]["save_ckpt"] = ckpt_model
                get_trainer(temp_config, init_model=self.models[-1]).run()
            else:
                # Restart through the CLI with the model definition blanked
                # out, so the run can only succeed if --use-pretrain-script
                # supplies the model parameters from the checkpoint.
                empty_config = deepcopy(config)
                empty_config["model"]["descriptor"] = {}
                empty_config["model"]["fitting_net"] = {}
                empty_config["training"]["numb_steps"] = 0
                empty_config["training"]["save_ckpt"] = ckpt_model
                # Write through the handle returned by NamedTemporaryFile and
                # close it before the CLI reads the file: the original kept
                # that handle open forever and opened the path a second time.
                with tempfile.NamedTemporaryFile(
                    "w", delete=False, suffix=".json"
                ) as tmp_input:
                    json.dump(empty_config, tmp_input, indent=4)
                try:
                    run_dp(
                        f"dp --pt train {tmp_input.name} --init-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat"
                    )
                finally:
                    # delete=False means nobody else removes the input file.
                    os.unlink(tmp_input.name)

            self.models.append(ckpt_model + ".pt")

    def test_dp_test(self):
        """Evaluate all three checkpoints on one frame and require equal outputs."""
        dp1 = DeepPot(str(self.models[0]))
        dp2 = DeepPot(str(self.models[1]))
        dp3 = DeepPot(str(self.models[2]))
        cell = np.array(
            [
                5.122106549439247480e00,
                4.016537340154059388e-01,
                6.951654033828678081e-01,
                4.016537340154059388e-01,
                6.112136112297989143e00,
                8.178091365465004481e-01,
                6.951654033828678081e-01,
                8.178091365465004481e-01,
                6.159552512682983760e00,
            ]
        ).reshape(1, 3, 3)
        coord = np.array(
            [
                2.978060152121375648e00,
                3.588469695887098077e00,
                2.792459820604495491e00,
                3.895592322591093115e00,
                2.712091020667753760e00,
                1.366836847133650501e00,
                9.955616170888935690e-01,
                4.121324820711413039e00,
                1.817239061889086571e00,
                3.553661462345699906e00,
                5.313046969500791583e00,
                6.635182659098815883e00,
                6.088601018589653080e00,
                6.575011420004332585e00,
                6.825240650611076099e00,
            ]
        ).reshape(1, -1, 3)
        atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)

        # The first five entries of each result are compared; the labels are
        # the short names the outputs have always been unpacked into here.
        labels = ("e", "f", "v", "ae", "av")
        ret1 = dp1.eval(coord, cell, atype, atomic=True)
        ret2 = dp2.eval(coord, cell, atype, atomic=True)
        ret3 = dp3.eval(coord, cell, atype, atomic=True)
        for idx, label in enumerate(labels):
            for other_name, other in (("model 1", ret2), ("model 2", ret3)):
                np.testing.assert_allclose(
                    ret1[idx],
                    other[idx],
                    rtol=1e-10,
                    atol=1e-10,
                    err_msg=f"{label} of model 0 differs from {other_name}",
                )

    def tearDown(self):
        """Remove checkpoints and training by-products from the working directory."""
        for f in os.listdir("."):
            if f.startswith("model") and f.endswith(".pt"):
                os.remove(f)
            if f in ["lcurve.out"]:
                os.remove(f)
            if f in ["stat_files"]:
                shutil.rmtree(f)

0 comments on commit 1c3e099

Please sign in to comment.