-
Notifications
You must be signed in to change notification settings - Fork 519
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement a JAX checkpoint format. I name it `*.jax` as I don't find existing conventions. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced serialization and deserialization functionalities for JAX backend models. - Added support for the `.jax` file suffix in the backend configuration. - Enhanced attribute handling logic across various classes to ensure proper processing of non-null values. - **Bug Fixes** - Enhanced cleanup processes in the test suite to improve reliability. - **Chores** - Updated dependencies in the project configuration for better JAX compatibility. - Adjusted linting rules to accommodate JAX-related code. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
c870ccf
commit 4d50048
Showing
11 changed files
with
156 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from pathlib import ( | ||
Path, | ||
) | ||
|
||
import orbax.checkpoint as ocp | ||
|
||
from deepmd.jax.env import ( | ||
jax, | ||
nnx, | ||
) | ||
from deepmd.jax.model.model import ( | ||
BaseModel, | ||
get_model, | ||
) | ||
|
||
|
||
def deserialize_to_file(model_file: str, data: dict) -> None: | ||
"""Deserialize the dictionary to a model file. | ||
Parameters | ||
---------- | ||
model_file : str | ||
The model file to be saved. | ||
data : dict | ||
The dictionary to be deserialized. | ||
""" | ||
if model_file.endswith(".jax"): | ||
model = BaseModel.deserialize(data["model"]) | ||
model_def_script = data["model_def_script"] | ||
_, state = nnx.split(model) | ||
with ocp.Checkpointer( | ||
ocp.CompositeCheckpointHandler("state", "model_def_script") | ||
) as checkpointer: | ||
checkpointer.save( | ||
Path(model_file).absolute(), | ||
ocp.args.Composite( | ||
state=ocp.args.StandardSave(state.to_pure_dict()), | ||
model_def_script=ocp.args.JsonSave(model_def_script), | ||
), | ||
) | ||
else: | ||
raise ValueError("JAX backend only supports converting .jax directory") | ||
|
||
|
||
def serialize_from_file(model_file: str) -> dict: | ||
"""Serialize the model file to a dictionary. | ||
Parameters | ||
---------- | ||
model_file : str | ||
The model file to be serialized. | ||
Returns | ||
------- | ||
dict | ||
The serialized model data. | ||
""" | ||
if model_file.endswith(".jax"): | ||
with ocp.Checkpointer( | ||
ocp.CompositeCheckpointHandler("state", "model_def_script") | ||
) as checkpointer: | ||
data = checkpointer.restore( | ||
Path(model_file).absolute(), | ||
ocp.args.Composite( | ||
state=ocp.args.StandardRestore(), | ||
model_def_script=ocp.args.JsonRestore(), | ||
), | ||
) | ||
state = data.state | ||
|
||
# convert str "1" to int 1 key | ||
def convert_str_to_int_key(item: dict): | ||
for key, value in item.copy().items(): | ||
if isinstance(value, dict): | ||
convert_str_to_int_key(value) | ||
if key.isdigit(): | ||
item[int(key)] = item.pop(key) | ||
|
||
convert_str_to_int_key(state) | ||
|
||
model_def_script = data.model_def_script | ||
abstract_model = get_model(model_def_script) | ||
graphdef, abstract_state = nnx.split(abstract_model) | ||
abstract_state.replace_by_pure_dict(state) | ||
model = nnx.merge(graphdef, abstract_state) | ||
model_dict = model.serialize() | ||
data = { | ||
"backend": "JAX", | ||
"jax_version": jax.__version__, | ||
"model": model_dict, | ||
"model_def_script": model_def_script, | ||
"@variables": {}, | ||
} | ||
return data | ||
else: | ||
raise ValueError("JAX backend only supports converting .jax directory") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters