Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): SavedModel C++ interface (including DPA-2 supports) #4307

Merged
merged 81 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
147400e
feat: saved model C++ interface
njzjz Nov 4, 2024
8c6d522
model
njzjz Nov 4, 2024
140f3e1
update test data
njzjz Nov 4, 2024
a0b8074
need CPU model
njzjz Nov 4, 2024
e6bf59f
skip memory check
njzjz Nov 4, 2024
6c10e8e
fix
njzjz Nov 4, 2024
2aa6deb
Apply suggestions from code review
njzjz Nov 4, 2024
f16dd92
Update source/api_cc/src/DeepPotJAX.cc
njzjz Nov 4, 2024
297ae26
debug memory leak
njzjz Nov 4, 2024
e64e06a
add LAMMPS test
njzjz Nov 4, 2024
8fefce8
fix memory leak in add_input
njzjz Nov 4, 2024
261c7bd
pass reference
njzjz Nov 4, 2024
4d5ccc5
delete function and retvals
njzjz Nov 4, 2024
d365bbc
Merge branch 'savedmodel-cxx-debug-mem' into savedmodel-cxx
njzjz Nov 4, 2024
21fc045
no need to skip the test
njzjz Nov 4, 2024
660171e
Merge branch 'devel' into savedmodel-cxx
njzjz Nov 4, 2024
d552821
Merge remote-tracking branch 'origin/devel' into savedmodel-cxx
njzjz Nov 4, 2024
0461248
add limitation
njzjz Nov 4, 2024
f26f3fe
fix tf string parse
njzjz Nov 4, 2024
713d065
Update source/api_cc/tests/test_deeppot_jax.cc
njzjz Nov 4, 2024
ccb182d
cast void*
njzjz Nov 4, 2024
8ccead6
handle zero atom
njzjz Nov 5, 2024
904042d
Merge branch 'devel' into jax-cxx-dpa2
njzjz Nov 5, 2024
0f9d5c5
feat(jax): DPA-2 for LAMMPS
njzjz Nov 5, 2024
bad564b
use the cpu model
njzjz Nov 5, 2024
2b165d7
fix function name
njzjz Nov 5, 2024
e717ba3
fix typos
njzjz Nov 5, 2024
f075075
nloc_real -> nall_real
njzjz Nov 6, 2024
58dcf2b
document limation
njzjz Nov 6, 2024
d93d13a
Merge branch 'devel' into jax-cxx-dpa1
Nov 9, 2024
232f7cd
fix(tf): fix normalize when compressing a model converted from other …
Nov 10, 2024
ce9ee61
apply padding method
Nov 10, 2024
6b10eb7
update model
njzjz Nov 11, 2024
afc71cb
Merge commit 'ce9ee61e71b83d2c682522706f98955dfecea98a' into jax-cxx-…
njzjz Nov 11, 2024
e1a2b55
Merge remote-tracking branch 'origin/devel' into reformat-jax-cxx
njzjz Nov 11, 2024
649f98e
update base class
njzjz Nov 11, 2024
1cad0b2
perhaps PADDING_FACTOR doesn't need so much
njzjz Nov 11, 2024
239d186
use max size
njzjz Nov 11, 2024
37c8739
bump API version
njzjz Nov 11, 2024
b863c79
update model
njzjz Nov 11, 2024
95ad9d0
update model
njzjz Nov 11, 2024
b6d039f
Revert "use max size"
njzjz Nov 11, 2024
5e2ea67
test
njzjz Nov 11, 2024
72a23d2
debug
njzjz Nov 11, 2024
edc4445
add all functions
njzjz Nov 11, 2024
b0808f1
Reapply "use max size"
njzjz Nov 11, 2024
458be34
Revert "debug"
njzjz Nov 11, 2024
87908c3
Revert "test"
njzjz Nov 11, 2024
3a0ca2d
Revert "update model"
njzjz Nov 11, 2024
1863b27
Revert "update model"
njzjz Nov 11, 2024
eb549e5
cast type
njzjz Nov 11, 2024
8a154bd
update model
njzjz Nov 11, 2024
4dab4fb
bugfix
njzjz Nov 11, 2024
c4f08c8
fix OOM issue
njzjz Nov 11, 2024
ef70135
no nlist interface
njzjz Nov 11, 2024
be02814
fix skip
njzjz Nov 11, 2024
3c46f37
try to reduce memory
njzjz Nov 12, 2024
49f57bc
fix skip tests
njzjz Nov 12, 2024
e8a99f4
also skip lammps dpa-2 tests for CUDA
njzjz Nov 12, 2024
8f83a28
should be fw
njzjz Nov 12, 2024
8c05d54
Revert "should be fw"
njzjz Nov 12, 2024
88be054
Revert "try to reduce memory"
njzjz Nov 12, 2024
5cfc83c
Revert "fix OOM issue"
njzjz Nov 12, 2024
9af5267
set --clean-durations
njzjz Nov 12, 2024
01567d6
Merge branch 'devel' into savedmodel-cxx
njzjz Nov 12, 2024
dc4a9d7
Merge remote-tracking branch 'origin/devel' into savedmodel-cxx
njzjz Nov 12, 2024
1234489
add example
njzjz Nov 12, 2024
86d1b7a
convert models at runtime
njzjz Nov 12, 2024
93cc440
add script path
njzjz Nov 12, 2024
546f7dc
revert strict=False
njzjz Nov 12, 2024
0d51bcc
revert .gitignore
njzjz Nov 12, 2024
fc1f90d
prefer cuda's cudnn
njzjz Nov 13, 2024
9447603
bump cuda version
njzjz Nov 13, 2024
6d5b45a
debug
njzjz Nov 13, 2024
e569ed9
fix docker name
njzjz Nov 13, 2024
1b3fd5e
set allow_growth to True
njzjz Nov 13, 2024
9d95778
fix compile error
njzjz Nov 13, 2024
09efdd3
fix typo
njzjz Nov 13, 2024
39f357c
call TFE_ContextOptionsSetConfig
njzjz Nov 13, 2024
cfff834
fix config
njzjz Nov 13, 2024
ca02625
Revert "debug"
njzjz Nov 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_cc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
TF_INTER_OP_PARALLELISM_THREADS: 1
LMP_CXX11_ABI_0: 1
CMAKE_GENERATOR: Ninja
CXXFLAGS: ${{ matrix.check_memleak && '-fsanitize=leak' || '' }}
CXXFLAGS: ${{ matrix.check_memleak && '-fsanitize=leak -DENABLE_SANITIZE=1' || '' }}
# test lammps
- run: |
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
Expand Down
3 changes: 2 additions & 1 deletion doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different

[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
Currently, this backend is developed actively, and has no support for training and the C++ interface.
Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface.
Currently, this backend is developed actively, and has no support for training.
njzjz marked this conversation as resolved.
Show resolved Hide resolved
njzjz marked this conversation as resolved.
Show resolved Hide resolved

### DP {{ dpmodel_icon }}

Expand Down
10 changes: 6 additions & 4 deletions doc/install/install-from-source.md
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library.
njzjz marked this conversation as resolved.
Show resolved Hide resolved

Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually.

Expand Down Expand Up @@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake

Expand Down Expand Up @@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ tensorflow_icon }} Whether building the TensorFlow backend.
{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.
njzjz marked this conversation as resolved.
Show resolved Hide resolved

:::

Expand All @@ -391,7 +393,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `PATH`

{{ tensorflow_icon }} The Path to TensorFlow's C++ interface.
{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface.

:::

Expand Down
249 changes: 249 additions & 0 deletions source/api_cc/include/DeepPotJAX.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <tensorflow/c/c_api.h>
#include <tensorflow/c/eager/c_api.h>

#include "DeepPot.h"
#include "common.h"
#include "neighbor_list.h"

namespace deepmd {
/**
* @brief TensorFlow implementation for Deep Potential.
**/
class DeepPotJAX : public DeepPotBase {
public:
/**
* @brief DP constructor without initialization.
**/
DeepPotJAX();
virtual ~DeepPotJAX();
/**
* @brief DP constructor with initialization.
* @param[in] model The name of the frozen model file.
* @param[in] gpu_rank The GPU rank. Default is 0.
* @param[in] file_content The content of the model file. If it is not empty,
*DP will read from the string instead of the file.
**/
DeepPotJAX(const std::string& model,
const int& gpu_rank = 0,
const std::string& file_content = "");
/**
* @brief Initialize the DP.
* @param[in] model The name of the frozen model file.
* @param[in] gpu_rank The GPU rank. Default is 0.
* @param[in] file_content The content of the model file. If it is not empty,
*DP will read from the string instead of the file.
**/
void init(const std::string& model,
const int& gpu_rank = 0,
const std::string& file_content = "");
/**
* @brief Get the cutoff radius.
* @return The cutoff radius.
**/
double cutoff() const {
assert(inited);
return rcut;
};
njzjz marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Get the number of types.
* @return The number of types.
**/
int numb_types() const {

Check warning on line 54 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L54

Added line #L54 was not covered by tests
assert(inited);
return ntypes;

Check warning on line 56 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L56

Added line #L56 was not covered by tests
};
/**
* @brief Get the number of types with spin.
* @return The number of types with spin.
**/
int numb_types_spin() const {

Check warning on line 62 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L62

Added line #L62 was not covered by tests
assert(inited);
return 0;

Check warning on line 64 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L64

Added line #L64 was not covered by tests
};
/**
* @brief Get the dimension of the frame parameter.
* @return The dimension of the frame parameter.
**/
int dim_fparam() const {

Check warning on line 70 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L70

Added line #L70 was not covered by tests
assert(inited);
return dfparam;

Check warning on line 72 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L72

Added line #L72 was not covered by tests
};
/**
* @brief Get the dimension of the atomic parameter.
* @return The dimension of the atomic parameter.
**/
int dim_aparam() const {

Check warning on line 78 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L78

Added line #L78 was not covered by tests
assert(inited);
return daparam;

Check warning on line 80 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L80

Added line #L80 was not covered by tests
};
/**
* @brief Get the type map (element name of the atom types) of this model.
* @param[out] type_map The type map of this model.
**/
void get_type_map(std::string& type_map);

/**
* @brief Get whether the atom dimension of aparam is nall instead of fparam.
* @param[out] aparam_nall whether the atom dimension of aparam is nall
*instead of fparam.
**/
bool is_aparam_nall() const {

Check warning on line 93 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L93

Added line #L93 was not covered by tests
assert(inited);
return false;

Check warning on line 95 in source/api_cc/include/DeepPotJAX.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DeepPotJAX.h#L95

Added line #L95 was not covered by tests
};

// forward to template class
void computew(std::vector<double>& ener,
std::vector<double>& force,
std::vector<double>& virial,
std::vector<double>& atom_energy,
std::vector<double>& atom_virial,
const std::vector<double>& coord,
const std::vector<int>& atype,
const std::vector<double>& box,
const std::vector<double>& fparam,
const std::vector<double>& aparam,
const bool atomic);
void computew(std::vector<double>& ener,
std::vector<float>& force,
std::vector<float>& virial,
std::vector<float>& atom_energy,
std::vector<float>& atom_virial,
const std::vector<float>& coord,
const std::vector<int>& atype,
const std::vector<float>& box,
const std::vector<float>& fparam,
const std::vector<float>& aparam,
const bool atomic);
void computew(std::vector<double>& ener,
std::vector<double>& force,
std::vector<double>& virial,
std::vector<double>& atom_energy,
std::vector<double>& atom_virial,
const std::vector<double>& coord,
const std::vector<int>& atype,
const std::vector<double>& box,
const int nghost,
const InputNlist& inlist,
const int& ago,
const std::vector<double>& fparam,
const std::vector<double>& aparam,
const bool atomic);
void computew(std::vector<double>& ener,
std::vector<float>& force,
std::vector<float>& virial,
std::vector<float>& atom_energy,
std::vector<float>& atom_virial,
const std::vector<float>& coord,
const std::vector<int>& atype,
const std::vector<float>& box,
const int nghost,
const InputNlist& inlist,
const int& ago,
const std::vector<float>& fparam,
const std::vector<float>& aparam,
const bool atomic);
void computew_mixed_type(std::vector<double>& ener,
std::vector<double>& force,
std::vector<double>& virial,
std::vector<double>& atom_energy,
std::vector<double>& atom_virial,
const int& nframes,
const std::vector<double>& coord,
const std::vector<int>& atype,
const std::vector<double>& box,
const std::vector<double>& fparam,
const std::vector<double>& aparam,
const bool atomic);
void computew_mixed_type(std::vector<double>& ener,
std::vector<float>& force,
std::vector<float>& virial,
std::vector<float>& atom_energy,
std::vector<float>& atom_virial,
const int& nframes,
const std::vector<float>& coord,
const std::vector<int>& atype,
const std::vector<float>& box,
const std::vector<float>& fparam,
const std::vector<float>& aparam,
const bool atomic);

private:
bool inited;
// device
std::string device;
// the cutoff radius
double rcut;
// the number of types
int ntypes;
// the dimension of the frame parameter
int dfparam;
// the dimension of the atomic parameter
int daparam;
// type map
std::string type_map;
// sel
std::vector<int64_t> sel;
// number of neighbors
int nnei;
/** TF C API objects.
* @{
*/
TF_Graph* graph;
TF_Status* status;
TF_Session* session;
TF_SessionOptions* sessionopts;
TFE_ContextOptions* ctx_opts;
TFE_Context* ctx;
std::vector<TF_Function*> func_vector;
/**
* @}
*/
njzjz marked this conversation as resolved.
Show resolved Hide resolved
// neighbor list data
NeighborListData nlist_data;
/**
* @brief Evaluate the energy, force, virial, atomic energy, and atomic virial
*by using this DP.
* @param[out] ener The system energy.
* @param[out] force The force on each atom.
* @param[out] virial The virial.
* @param[out] atom_energy The atomic energy.
* @param[out] atom_virial The atomic virial.
* @param[in] coord The coordinates of atoms. The array should be of size
*nframes x natoms x 3.
* @param[in] atype The atom types. The list should contain natoms ints.
* @param[in] box The cell of the region. The array should be of size nframes
*x 9.
* @param[in] nghost The number of ghost atoms.
* @param[in] lmp_list The input neighbour list.
* @param[in] ago Update the internal neighbour list if ago is 0.
* @param[in] fparam The frame parameter. The array can be of size :
* nframes x dim_fparam.
* dim_fparam. Then all frames are assumed to be provided with the same
*fparam.
* @param[in] aparam The atomic parameter The array can be of size :
* nframes x natoms x dim_aparam.
* natoms x dim_aparam. Then all frames are assumed to be provided with the
*same aparam.
* @param[in] atomic Whether to compute atomic energy and virial.
**/
template <typename VALUETYPE>
void compute(std::vector<ENERGYTYPE>& ener,
std::vector<VALUETYPE>& force,
std::vector<VALUETYPE>& virial,
std::vector<VALUETYPE>& atom_energy,
std::vector<VALUETYPE>& atom_virial,
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam,
const bool atomic);
};
} // namespace deepmd
2 changes: 1 addition & 1 deletion source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace deepmd {

typedef double ENERGYTYPE;
enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown };
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
njzjz marked this conversation as resolved.
Show resolved Hide resolved
njzjz marked this conversation as resolved.
Show resolved Hide resolved

struct NeighborListData {
/// Array stores the core region atom's index
Expand Down
12 changes: 12 additions & 0 deletions source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "AtomMap.h"
#include "common.h"
#ifdef BUILD_TENSORFLOW
#include "DeepPotJAX.h"
#include "DeepPotTF.h"
#endif
#ifdef BUILD_PYTORCH
Expand Down Expand Up @@ -41,6 +42,9 @@ void DeepPot::init(const std::string& model,
backend = deepmd::DPBackend::PyTorch;
} else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") {
backend = deepmd::DPBackend::TensorFlow;
} else if (model.length() >= 11 &&
model.substr(model.length() - 11) == ".savedmodel") {
backend = deepmd::DPBackend::JAX;
} else {
throw deepmd::deepmd_exception("Unsupported model file format");
}
Expand All @@ -58,6 +62,14 @@ void DeepPot::init(const std::string& model,
#endif
} else if (deepmd::DPBackend::Paddle == backend) {
throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet");
} else if (deepmd::DPBackend::JAX == backend) {
#ifdef BUILD_TENSORFLOW
dp = std::make_shared<deepmd::DeepPotJAX>(model, gpu_rank, file_content);
#else
throw deepmd::deepmd_exception(
"TensorFlow backend is not built, which is used to load JAX2TF "
"SavedModels");
#endif
} else {
throw deepmd::deepmd_exception("Unknown file type");
}
Expand Down
Loading