Skip to content

Commit

Permalink
Merge branch 'devel' into fix_gpu_ut
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored Jan 30, 2024
2 parents 4cd8258 + b800043 commit ca8083d
Show file tree
Hide file tree
Showing 24 changed files with 1,197 additions and 143 deletions.
6 changes: 6 additions & 0 deletions backend/find_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ def get_tf_requirement(tf_version: str = "") -> dict:
dict
TensorFlow requirement, including cpu and gpu.
"""
if tf_version is None:
return {
"cpu": [],
"gpu": [],
"mpi": [],
}
if tf_version == "":
tf_version = os.environ.get("TENSORFLOW_VERSION", "")

Expand Down
24 changes: 17 additions & 7 deletions backend/read_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,26 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str]:
cmake_args.append("-DENABLE_IPI:BOOL=TRUE")
extra_scripts["dp_ipi"] = "deepmd.tf.entrypoints.ipi:dp_ipi"

tf_install_dir, _ = find_tensorflow()
tf_version = get_tf_version(tf_install_dir)
if tf_version == "" or Version(tf_version) >= Version("2.12"):
find_libpython_requires = []
if os.environ.get("DP_ENABLE_TENSORFLOW", "1") == "1":
tf_install_dir, _ = find_tensorflow()
tf_version = get_tf_version(tf_install_dir)
if tf_version == "" or Version(tf_version) >= Version("2.12"):
find_libpython_requires = []
else:
find_libpython_requires = ["find_libpython"]
cmake_args.extend(
[
"-DENABLE_TENSORFLOW=ON",
f"-DTENSORFLOW_VERSION={tf_version}",
f"-DTENSORFLOW_ROOT:PATH={tf_install_dir}",
]
)
else:
find_libpython_requires = ["find_libpython"]
cmake_args.append(f"-DTENSORFLOW_VERSION={tf_version}")
find_libpython_requires = []
cmake_args.append("-DENABLE_TENSORFLOW=OFF")
tf_version = None

cmake_args = [
f"-DTENSORFLOW_ROOT:PATH={tf_install_dir}",
"-DBUILD_PY_IF:BOOL=TRUE",
*cmake_args,
]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from .env_mat import (
EnvMat,
)
from .fitting import (
InvarFitting,
)
from .network import (
EmbeddingNet,
FittingNet,
Expand Down Expand Up @@ -34,6 +37,7 @@
)

__all__ = [
"InvarFitting",
"DescrptSeA",
"EnvMat",
"make_multilayer_network",
Expand Down
Loading

0 comments on commit ca8083d

Please sign in to comment.