Skip to content

Commit

Permalink
feat: add get_model classmethod to BaseModel (deepmodeling#4002)
Browse files Browse the repository at this point in the history
Fix deepmodeling#3968. External and new models can implement this method (if
different from default) without changing the old `get_model` methods
(which cannot be done by a plugin).

Note: I don't modify old `get_model` methods in this PR.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Introduced a new method for model instantiation that enhances
flexibility in parameter configuration.
- Improved the model retrieval process to support dynamic model
selection based on specified types.
  
- **Bug Fixes**
- Enhanced control flow to ensure correct model type selection,
addressing potential issues with model retrieval.

- **Refactor**
- Updated existing model retrieval functions to streamline logic and
improve maintainability.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent e5ee94d commit 0c3ebe3
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
25 changes: 25 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import inspect
import json
from abc import (
ABC,
abstractmethod,
Expand Down Expand Up @@ -193,6 +194,30 @@ def update_sel(
cls = cls.get_class_by_type(model_type)
return cls.update_sel(train_data, type_map, local_jdata)

@classmethod
def get_model(cls, model_params: dict) -> "BaseBaseModel":
"""Get the model by the parameters.
By default, all the parameters are directly passed to the constructor.
If not, override this method.
Parameters
----------
model_params : dict
The model parameters
Returns
-------
BaseBaseModel
The model
"""
model_params_old = model_params.copy()
model_params = model_params.copy()
model_params.pop("type", None)
model = cls(**model_params)
model.model_def_script = json.dumps(model_params_old)
return model

return BaseBaseModel


Expand Down
13 changes: 10 additions & 3 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)
Expand Down Expand Up @@ -93,7 +96,11 @@ def get_model(data: dict):
data : dict
The data to construct the model.
"""
if "spin" in data:
return get_spin_model(data)
model_type = data.get("type", "standard")
if model_type == "standard":
if "spin" in data:
return get_spin_model(data)
else:
return get_standard_model(data)
else:
return get_standard_model(data)
return BaseModel.get_class_by_type(model_type).get_model(data)
14 changes: 9 additions & 5 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,16 @@ def get_standard_model(model_params):


def get_model(model_params):
if "spin" in model_params:
return get_spin_model(model_params)
elif "use_srtab" in model_params:
return get_zbl_model(model_params)
model_type = model_params.get("type", "standard")
if model_type == "standard":
if "spin" in model_params:
return get_spin_model(model_params)
elif "use_srtab" in model_params:
return get_zbl_model(model_params)
else:
return get_standard_model(model_params)
else:
return get_standard_model(model_params)
return BaseModel.get_class_by_type(model_type).get_model(model_params)


__all__ = [
Expand Down

0 comments on commit 0c3ebe3

Please sign in to comment.