Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt/tf): init-(frz)-model use pretrain script #3926

Merged
merged 5 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ def main_parser() -> argparse.ArgumentParser:
parser_train.add_argument(
"--use-pretrain-script",
action="store_true",
help="Use model parameters from the script of the pretrained model instead of user input when doing finetuning. Note: This behavior is default and unchangeable in TensorFlow.",
help="When performing fine-tuning or init-model, "
"utilize the model parameters provided by the script of the pretrained model rather than relying on user input. "
"It is important to note that in TensorFlow, this behavior is the default and cannot be modified for fine-tuning. ",
)
parser_train.add_argument(
"-o",
Expand Down
15 changes: 15 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,21 @@
model_branch=FLAGS.model_branch,
change_model_params=FLAGS.use_pretrain_script,
)
# update init_model or init_frz_model config if necessary
if (

Check warning on line 247 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L247

Added line #L247 was not covered by tests
FLAGS.init_model is not None or FLAGS.init_frz_model is not None
) and FLAGS.use_pretrain_script:
if FLAGS.init_model is not None:
init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]

Check warning on line 254 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L250-L254

Added lines #L250 - L254 were not covered by tests
else:
config["model"] = json.loads(

Check warning on line 256 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L256

Added line #L256 was not covered by tests
torch.jit.load(
FLAGS.init_frz_model, map_location=DEVICE
).get_model_def_script()
)

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
Expand Down
39 changes: 39 additions & 0 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
is_compress: bool = False,
skip_neighbor_stat: bool = False,
finetune: Optional[str] = None,
use_pretrain_script: bool = False,
**kwargs,
):
"""Run DeePMD model training.
Expand Down Expand Up @@ -93,6 +94,9 @@
skip checking neighbor statistics
finetune : Optional[str]
path to pretrained model or None
use_pretrain_script : bool
Whether to use model script in pretrained model when doing init-model or init-frz-model.
Note that this option is true and unchangeable for fine-tuning.
**kwargs
additional arguments

Expand Down Expand Up @@ -123,6 +127,41 @@
jdata, run_opt.finetune
)

if (

Check warning on line 130 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L130

Added line #L130 was not covered by tests
run_opt.init_model is not None or run_opt.init_frz_model is not None
) and use_pretrain_script:
from deepmd.tf.utils.errors import (

Check warning on line 133 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L133

Added line #L133 was not covered by tests
GraphWithoutTensorError,
)
from deepmd.tf.utils.graph import (

Check warning on line 136 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L136

Added line #L136 was not covered by tests
get_tensor_by_name,
get_tensor_by_name_from_graph,
)

err_msg = (

Check warning on line 141 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L141

Added line #L141 was not covered by tests
f"The input model: {run_opt.init_model if run_opt.init_model is not None else run_opt.init_frz_model} has no training script, "
f"Please use the model pretrained with v2.1.5 or higher version of DeePMD-kit."
)
if run_opt.init_model is not None:
with tf.Graph().as_default() as graph:
tf.train.import_meta_graph(

Check warning on line 147 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L145-L147

Added lines #L145 - L147 were not covered by tests
f"{run_opt.init_model}.meta", clear_devices=True
)
try:
t_training_script = get_tensor_by_name_from_graph(

Check warning on line 151 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L150-L151

Added lines #L150 - L151 were not covered by tests
graph, "train_attr/training_script"
)
except GraphWithoutTensorError as e:
raise RuntimeError(err_msg) from e

Check warning on line 155 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L154-L155

Added lines #L154 - L155 were not covered by tests
else:
try:
t_training_script = get_tensor_by_name(

Check warning on line 158 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L157-L158

Added lines #L157 - L158 were not covered by tests
run_opt.init_frz_model, "train_attr/training_script"
)
except GraphWithoutTensorError as e:
raise RuntimeError(err_msg) from e
jdata["model"] = json.loads(t_training_script)["model"]

Check warning on line 163 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L161-L163

Added lines #L161 - L163 were not covered by tests

jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

jdata = normalize(jdata)
Expand Down
52 changes: 45 additions & 7 deletions source/tests/pt/test_init_frz_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import shutil
import tempfile
import unittest
from argparse import (
Namespace,
Expand Down Expand Up @@ -27,6 +30,7 @@ def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
config = json.load(f)
config["model"]["descriptor"]["smooth_type_embedding"] = True
config["training"]["numb_steps"] = 1
config["training"]["save_freq"] = 1
config["learning_rate"]["start_lr"] = 1.0
Expand All @@ -38,15 +42,30 @@ def setUp(self):
]

self.models = []
for imodel in range(2):
if imodel == 1:
config["training"]["numb_steps"] = 0
trainer = get_trainer(deepcopy(config), init_frz_model=self.models[-1])
for imodel in range(3):
frozen_model = f"frozen_model{imodel}.pth"
if imodel == 0:
temp_config = deepcopy(config)
trainer = get_trainer(temp_config)
elif imodel == 1:
temp_config = deepcopy(config)
temp_config["training"]["numb_steps"] = 0
trainer = get_trainer(temp_config, init_frz_model=self.models[-1])
else:
trainer = get_trainer(deepcopy(config))
trainer.run()
empty_config = deepcopy(config)
empty_config["model"]["descriptor"] = {}
empty_config["model"]["fitting_net"] = {}
empty_config["training"]["numb_steps"] = 0
tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
with open(tmp_input.name, "w") as f:
json.dump(empty_config, f, indent=4)
os.system(
f"dp --pt train {tmp_input.name} --init-frz-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat"
)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
trainer = None

frozen_model = f"frozen_model{imodel}.pth"
if imodel in [0, 1]:
trainer.run()
ns = Namespace(
model="model.pt",
output=frozen_model,
Expand All @@ -58,6 +77,7 @@ def setUp(self):
def test_dp_test(self):
dp1 = DeepPot(str(self.models[0]))
dp2 = DeepPot(str(self.models[1]))
dp3 = DeepPot(str(self.models[2]))
cell = np.array(
[
5.122106549439247480e00,
Expand Down Expand Up @@ -96,8 +116,26 @@ def test_dp_test(self):
e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4]
ret2 = dp2.eval(coord, cell, atype, atomic=True)
e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4]
ret3 = dp3.eval(coord, cell, atype, atomic=True)
e3, f3, v3, ae3, av3 = ret3[0], ret3[1], ret3[2], ret3[3], ret3[4]
np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(e1, e3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(f1, f3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(v1, v3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(ae1, ae3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(av1, av3, rtol=1e-10, atol=1e-10)

def tearDown(self):
for f in os.listdir("."):
if f.startswith("model") and f.endswith(".pt"):
os.remove(f)
if f.startswith("frozen_model") and f.endswith(".pth"):
os.remove(f)
if f in ["lcurve.out"]:
os.remove(f)
if f in ["stat_files"]:
shutil.rmtree(f)
132 changes: 132 additions & 0 deletions source/tests/pt/test_init_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
)
from pathlib import (
Path,
)

import numpy as np

from deepmd.pt.entrypoints.main import (
get_trainer,
)
from deepmd.pt.infer.deep_eval import (
DeepPot,
)


class TestInitModel(unittest.TestCase):
    """End-to-end test that ``--init-model`` with ``--use-pretrain-script``
    reproduces the pretrained model's predictions.

    ``setUp`` produces three checkpoints:
      0. a model trained for one step;
      1. the same model re-initialized through the Python API
         (``init_model``) with zero further training steps;
      2. the same model re-initialized through the CLI with
         ``--use-pretrain-script`` and an intentionally empty ``model``
         section in the input JSON, also with zero training steps.
    ``test_dp_test`` then asserts all three yield identical predictions
    on a fixed water configuration.
    """

    def setUp(self):
        input_json = str(Path(__file__).parent / "water/se_atten.json")
        with open(input_json) as f:
            config = json.load(f)
        config["model"]["descriptor"]["smooth_type_embedding"] = True
        config["training"]["numb_steps"] = 1
        config["training"]["save_freq"] = 1
        config["learning_rate"]["start_lr"] = 1.0
        config["training"]["training_data"]["systems"] = [
            str(Path(__file__).parent / "water/data/single")
        ]
        config["training"]["validation_data"]["systems"] = [
            str(Path(__file__).parent / "water/data/single")
        ]

        self.models = []
        # temp input JSON files created for the CLI run; removed in tearDown
        self.tmp_inputs = []
        for imodel in range(3):
            ckpt_model = f"model{imodel}.ckpt"
            if imodel == 0:
                temp_config = deepcopy(config)
                temp_config["training"]["save_ckpt"] = ckpt_model
                trainer = get_trainer(temp_config)
            elif imodel == 1:
                temp_config = deepcopy(config)
                temp_config["training"]["numb_steps"] = 0
                temp_config["training"]["save_ckpt"] = ckpt_model
                trainer = get_trainer(temp_config, init_model=self.models[-1])
            else:
                # Leave the model section empty on purpose: it must be
                # filled in from the pretrained checkpoint by
                # --use-pretrain-script, which is what this test verifies.
                empty_config = deepcopy(config)
                empty_config["model"]["descriptor"] = {}
                empty_config["model"]["fitting_net"] = {}
                empty_config["training"]["numb_steps"] = 0
                empty_config["training"]["save_ckpt"] = ckpt_model
                tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
                self.tmp_inputs.append(tmp_input.name)
                with open(tmp_input.name, "w") as f:
                    json.dump(empty_config, f, indent=4)
                ret = os.system(
                    f"dp --pt train {tmp_input.name} --init-model {self.models[-1]} "
                    "--use-pretrain-script --skip-neighbor-stat"
                )
                # Fail fast with a clear message instead of a confusing
                # DeepPot load error later in test_dp_test.
                self.assertEqual(
                    ret, 0, "CLI training with --use-pretrain-script failed"
                )
                trainer = None

            if imodel in [0, 1]:
                trainer.run()
            self.models.append(ckpt_model + ".pt")

    def test_dp_test(self):
        """All three checkpoints must produce identical predictions."""
        dp1 = DeepPot(str(self.models[0]))
        dp2 = DeepPot(str(self.models[1]))
        dp3 = DeepPot(str(self.models[2]))
        cell = np.array(
            [
                5.122106549439247480e00,
                4.016537340154059388e-01,
                6.951654033828678081e-01,
                4.016537340154059388e-01,
                6.112136112297989143e00,
                8.178091365465004481e-01,
                6.951654033828678081e-01,
                8.178091365465004481e-01,
                6.159552512682983760e00,
            ]
        ).reshape(1, 3, 3)
        coord = np.array(
            [
                2.978060152121375648e00,
                3.588469695887098077e00,
                2.792459820604495491e00,
                3.895592322591093115e00,
                2.712091020667753760e00,
                1.366836847133650501e00,
                9.955616170888935690e-01,
                4.121324820711413039e00,
                1.817239061889086571e00,
                3.553661462345699906e00,
                5.313046969500791583e00,
                6.635182659098815883e00,
                6.088601018589653080e00,
                6.575011420004332585e00,
                6.825240650611076099e00,
            ]
        ).reshape(1, -1, 3)
        atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)

        ret1 = dp1.eval(coord, cell, atype, atomic=True)
        e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4]
        ret2 = dp2.eval(coord, cell, atype, atomic=True)
        e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4]
        ret3 = dp3.eval(coord, cell, atype, atomic=True)
        e3, f3, v3, ae3, av3 = ret3[0], ret3[1], ret3[2], ret3[3], ret3[4]
        np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(e1, e3, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(f1, f3, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(v1, v3, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(ae1, ae3, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10)
        np.testing.assert_allclose(av1, av3, rtol=1e-10, atol=1e-10)

    def tearDown(self):
        """Remove checkpoints, training artifacts, and temp input files."""
        for f in os.listdir("."):
            if f.startswith("model") and f.endswith(".pt"):
                os.remove(f)
            if f in ["lcurve.out"]:
                os.remove(f)
            if f in ["stat_files"]:
                shutil.rmtree(f)
        # Clean up the temp JSON inputs created with delete=False in setUp;
        # getattr guards against setUp having failed before assigning it.
        for f in getattr(self, "tmp_inputs", []):
            if os.path.exists(f):
                os.remove(f)
Loading