Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): calculate stat during compression if --skip-neighbor-stat #4330

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions deepmd/pt/entrypoints/compress.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import logging
from typing import (
Optional,
)

import torch

from deepmd.common import (
j_loader,
)
from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.data_system import (
get_data,
)

log = logging.getLogger(__name__)


def enable_compression(
Expand All @@ -14,12 +35,44 @@
stride: float = 0.01,
extrapolate: int = 5,
check_frequency: int = -1,
training_script: Optional[str] = None,
):
saved_model = torch.jit.load(input_file, map_location="cpu")
model_def_script = json.loads(saved_model.model_def_script)
model = get_model(model_def_script)
model.load_state_dict(saved_model.state_dict())

if model.get_min_nbor_dist() is None:
log.info(
"Minimal neighbor distance is not saved in the model, compute it from the training data."
)
if training_script is None:
raise ValueError(

Check warning on line 50 in deepmd/pt/entrypoints/compress.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/compress.py#L50

Added line #L50 was not covered by tests
"The model does not have a minimum neighbor distance, "
"so the training script and data must be provided "
"(via -t,--training-script)."
)

jdata = j_loader(training_script)
jdata = update_deepmd_input(jdata)

type_map = jdata["model"].get("type_map", None)
train_data = get_data(
jdata["training"]["training_data"],
0, # not used
type_map,
None,
)
update_sel = UpdateSel()
t_min_nbor_dist = update_sel.get_min_nbor_dist(
train_data,
)
model.min_nbor_dist = torch.tensor(
t_min_nbor_dist,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved

model.enable_compression(
extrapolate,
stride,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
stride=FLAGS.step,
extrapolate=FLAGS.extrapolate,
check_frequency=FLAGS.frequency,
training_script=FLAGS.training_script,
)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")
Expand Down
159 changes: 158 additions & 1 deletion source/tests/pt/test_model_compression_se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,48 @@ def _init_models_exclude_types():
return INPUT, frozen_model, compressed_model


def _init_models_skip_neighbor_stat():
suffix = "-skip-neighbor-stat"
data_file = str(tests_path / os.path.join("model_compression", "data"))
frozen_model = str(tests_path / f"dp-original{suffix}.pth")
compressed_model = str(tests_path / f"dp-compressed{suffix}.pth")
INPUT = str(tests_path / "input.json")
jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json")))
jdata["training"]["training_data"]["systems"] = data_file
with open(INPUT, "w") as fp:
json.dump(jdata, fp, indent=4)

ret = run_dp("dp --pt train " + INPUT + " --skip-neighbor-stat")
np.testing.assert_equal(ret, 0, "DP train failed!")
ret = run_dp("dp --pt freeze -o " + frozen_model)
np.testing.assert_equal(ret, 0, "DP freeze failed!")
ret = run_dp(
"dp --pt compress "
+ " -i "
+ frozen_model
+ " -o "
+ compressed_model
+ " -t "
+ INPUT
)
np.testing.assert_equal(ret, 0, "DP model compression failed!")
return INPUT, frozen_model, compressed_model


def setUpModule():
global \
INPUT, \
FROZEN_MODEL, \
COMPRESSED_MODEL, \
INPUT_ET, \
FROZEN_MODEL_ET, \
COMPRESSED_MODEL_ET
COMPRESSED_MODEL_ET, \
FROZEN_MODEL_SKIP_NEIGHBOR_STAT, \
COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT
INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models()
_, FROZEN_MODEL_SKIP_NEIGHBOR_STAT, COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT = (
_init_models_skip_neighbor_stat()
)
INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types()


Expand Down Expand Up @@ -572,5 +605,129 @@ def test_2frame_atm(self):
np.testing.assert_almost_equal(vv0, vv1, default_places)


class TestSkipNeighborStat(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT)
cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT)
cls.coords = np.array(
[
12.83,
2.56,
2.18,
12.09,
2.87,
2.74,
00.25,
3.32,
1.68,
3.36,
3.00,
1.81,
3.51,
2.51,
2.60,
4.27,
3.22,
1.56,
]
)
cls.atype = [0, 1, 1, 0, 1, 1]
cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])

def test_attrs(self):
self.assertEqual(self.dp_original.get_ntypes(), 2)
self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places)
self.assertEqual(self.dp_original.get_type_map(), ["O", "H"])
self.assertEqual(self.dp_original.get_dim_fparam(), 0)
self.assertEqual(self.dp_original.get_dim_aparam(), 0)

self.assertEqual(self.dp_compressed.get_ntypes(), 2)
self.assertAlmostEqual(
self.dp_compressed.get_rcut(), 6.0, places=default_places
)
self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"])
self.assertEqual(self.dp_compressed.get_dim_fparam(), 0)
self.assertEqual(self.dp_compressed.get_dim_aparam(), 0)

def test_1frame(self):
ee0, ff0, vv0 = self.dp_original.eval(
self.coords, self.box, self.atype, atomic=False
)
ee1, ff1, vv1 = self.dp_compressed.eval(
self.coords, self.box, self.atype, atomic=False
)
# check shape of the returns
nframes = 1
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)

def test_1frame_atm(self):
ee0, ff0, vv0, ae0, av0 = self.dp_original.eval(
self.coords, self.box, self.atype, atomic=True
)
ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval(
self.coords, self.box, self.atype, atomic=True
)
# check shape of the returns
nframes = 1
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ae0.shape, (nframes, natoms, 1))
self.assertEqual(av0.shape, (nframes, natoms, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
self.assertEqual(ae1.shape, (nframes, natoms, 1))
self.assertEqual(av1.shape, (nframes, natoms, 9))
# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ae0, ae1, default_places)
np.testing.assert_almost_equal(av0, av1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)

def test_2frame_atm(self):
coords2 = np.concatenate((self.coords, self.coords))
box2 = np.concatenate((self.box, self.box))
ee0, ff0, vv0, ae0, av0 = self.dp_original.eval(
coords2, box2, self.atype, atomic=True
)
ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval(
coords2, box2, self.atype, atomic=True
)
# check shape of the returns
nframes = 2
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ae0.shape, (nframes, natoms, 1))
self.assertEqual(av0.shape, (nframes, natoms, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
self.assertEqual(ae1.shape, (nframes, natoms, 1))
self.assertEqual(av1.shape, (nframes, natoms, 9))

# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ae0, ae1, default_places)
np.testing.assert_almost_equal(av0, av1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)


if __name__ == "__main__":
unittest.main()