Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: support preset bias of atomic model output #4116

Merged
merged 8 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from deepmd.pt.model.task import (
BaseFitting,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)
from deepmd.utils.spin import (
Spin,
)
Expand Down Expand Up @@ -151,6 +154,22 @@ def get_zbl_model(model_params):
)


def _convert_preset_out_bias_to_torch_tensor(preset_out_bias, type_map):
if preset_out_bias is not None:
if preset_out_bias is not None:
njzjz marked this conversation as resolved.
Show resolved Hide resolved
for kk in preset_out_bias:
if len(preset_out_bias[kk]) != len(type_map):
raise ValueError(
"length of the preset_out_bias should be the same as the type_map"
)
for jj in range(len(preset_out_bias[kk])):
if preset_out_bias[kk][jj] is not None:
preset_out_bias[kk][jj] = to_torch_tensor(
np.array(preset_out_bias[kk][jj])
)
return preset_out_bias


def get_standard_model(model_params):
model_params_old = model_params
model_params = copy.deepcopy(model_params)
Expand All @@ -176,6 +195,10 @@ def get_standard_model(model_params):
fitting = BaseFitting(**fitting_net)
atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])
preset_out_bias = model_params.get("preset_out_bias")
preset_out_bias = _convert_preset_out_bias_to_torch_tensor(
preset_out_bias, model_params["type_map"]
)

if fitting_net["type"] == "dipole":
modelcls = DipoleModel
Expand All @@ -196,6 +219,7 @@ def get_standard_model(model_params):
type_map=model_params["type_map"],
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
preset_out_bias=preset_out_bias,
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
)
model.model_def_script = json.dumps(model_params_old)
return model
Expand Down
8 changes: 8 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,6 +1771,7 @@
doc_spin = "The settings for systems with spin."
doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types"
doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other."
doc_preset_out_bias = "The preset bias of the atomic output. Is provided as a dict. Taking the energy model that has three atom types for example, the preset_out_bias may be given as `{ 'energy': [None, 0., 1.] }`. In this case the bias of type 1 and 2 are set to 0. and 1., respectively. The set_davg_zero should be set to true."

Check warning on line 1774 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L1774

Added line #L1774 was not covered by tests
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
doc_finetune_head = (
"The chosen fitting net to fine-tune on, when doing multi-task fine-tuning. "
"If not set or set to 'RANDOM', the fitting net will be randomly initialized."
Expand Down Expand Up @@ -1833,6 +1834,13 @@
default=[],
doc=doc_only_pt_supported + doc_atom_exclude_types,
),
Argument(
"preset_out_bias",
dict,
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
optional=True,
default=None,
doc=doc_only_pt_supported + doc_preset_out_bias,
),
Argument(
"srtab_add_bias",
bool,
Expand Down
81 changes: 81 additions & 0 deletions source/tests/pt/model/test_get_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import unittest

import torch

from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.utils import (
env,
)

# Expected dtype of the preset bias tensors built by the model.
dtype = torch.float64

# Minimal se_e2_a model configuration used by all tests below.
# It sets every optional model-level attribute (atom/pair exclusions and
# preset_out_bias) so the tests can check they are forwarded to the
# atomic model.
model_se_e2_a = {
    "type_map": ["O", "H", "B"],
    "descriptor": {
        "type": "se_e2_a",
        "sel": [46, 92, 4],
        "rcut_smth": 0.50,
        "rcut": 4.00,
        "neuron": [25, 50, 100],
        "resnet_dt": False,
        "axis_neuron": 16,
        "seed": 1,
    },
    "fitting_net": {
        "neuron": [24, 24, 24],
        "resnet_dt": True,
        "seed": 1,
    },
    "data_stat_nbatch": 20,
    "atom_exclude_types": [1],
    "pair_exclude_types": [[1, 2]],
    # One bias entry per type in type_map; None means "no preset for
    # this type". Non-None entries are converted to torch tensors.
    "preset_out_bias": {
        "energy": [
            None,
            [1.0],
            [3.0],
        ]
    },
}


class TestGetModel(unittest.TestCase):
    """Check that get_model forwards model-level options to the atomic model."""

    def test_model_attr(self):
        # All optional attributes are present in the reference config.
        params = copy.deepcopy(model_se_e2_a)
        self.model = get_model(params).to(env.DEVICE)
        am = self.model.atomic_model
        self.assertEqual(am.type_map, ["O", "H", "B"])
        expected_bias = {
            "energy": [
                None,
                torch.tensor([1.0], dtype=dtype, device=env.DEVICE),
                torch.tensor([3.0], dtype=dtype, device=env.DEVICE),
            ]
        }
        self.assertEqual(am.preset_out_bias, expected_bias)
        self.assertEqual(am.atom_exclude_types, [1])
        self.assertEqual(am.pair_exclude_types, [[1, 2]])

    def test_notset_model_attr(self):
        # When the optional keys are absent, the documented defaults apply.
        params = copy.deepcopy(model_se_e2_a)
        for key in ("atom_exclude_types", "pair_exclude_types", "preset_out_bias"):
            params.pop(key)
        self.model = get_model(params).to(env.DEVICE)
        am = self.model.atomic_model
        self.assertEqual(am.type_map, ["O", "H", "B"])
        self.assertEqual(am.preset_out_bias, None)
        self.assertEqual(am.atom_exclude_types, [])
        self.assertEqual(am.pair_exclude_types, [])

    def test_preset_wrong_len(self):
        # A bias list shorter than the type map must be rejected.
        params = copy.deepcopy(model_se_e2_a)
        params["preset_out_bias"] = {"energy": [None]}
        with self.assertRaises(ValueError):
            self.model = get_model(params).to(env.DEVICE)
Loading