From 3d0bab67ce90a5da9217a3f61a570c7ec18efd3c Mon Sep 17 00:00:00 2001
From: Raimondas Galvelis
Date: Mon, 26 Apr 2021 22:44:02 +0200
Subject: [PATCH] A better PyTorch wrapper (#19)

* Move CustomANISymmetryFunctions construction to ANISymmetryFunctionsOp

* Move CustomANISymmetryFunctions construction to TorchANISymmetryFunctions.forward

* Move CustomANISymmetryFunctions construction to TorchANISymmetryFunctions.forward

* Create NNPOps::ANISymmetryFunctions namespace

* Simplify names

* Simplify types

* Fix typo

* Implement Holder::is_initialized

* Don't use Optional[Holder]

* Fix serialization

* Update the benchmark

* Update the build instructions

* Fix the constructor
---
 pytorch/BenchmarkTorchANISymmetryFunctions.py |   4 +-
 pytorch/README.md                             |  33 ++-
 pytorch/SymmetryFunctions.cpp                 | 198 +++++++++---------
 pytorch/SymmetryFunctions.py                  |  23 +-
 pytorch/environment.yml                       |  12 ++
 5 files changed, 146 insertions(+), 124 deletions(-)
 create mode 100644 pytorch/environment.yml

diff --git a/pytorch/BenchmarkTorchANISymmetryFunctions.py b/pytorch/BenchmarkTorchANISymmetryFunctions.py
index a2d5e7a..6299c73 100644
--- a/pytorch/BenchmarkTorchANISymmetryFunctions.py
+++ b/pytorch/BenchmarkTorchANISymmetryFunctions.py
@@ -40,7 +40,7 @@
 sum_aev.backward()
 grad = positions.grad.clone()
 
-N = 40000
+N = 100000
 start = time.time()
 for _ in range(N):
     aev = symmFunc(speciesPositions).aevs
@@ -55,7 +55,5 @@
 aev_error = torch.max(torch.abs(aev - aev_ref))
 grad_error = torch.max(torch.abs(grad - grad_ref))
 
-print(aev_error)
-print(grad_error)
 assert aev_error < 0.0002
 assert grad_error < 0.007
\ No newline at end of file
diff --git a/pytorch/README.md b/pytorch/README.md
index c13b3a7..d5c63e6 100644
--- a/pytorch/README.md
+++ b/pytorch/README.md
@@ -44,39 +44,32 @@ print(energy, forces)
 
 ### Build & install
 
-- Crate a *Conda* environment
+- Get the source code
 ```bash
-$ conda create -n nnpops \
-               -c pytorch \
-               -c conda-forge \
-               cmake \
-               git \
-               gxx_linux-64 \
-               make \
-               mdtraj \
-               pytest \
-               python=3.8 \
-               pytorch=1.6 \
-               torchani=2.2
-$ conda activate nnpops
+$ git clone https://github.com/openmm/NNPOps.git
 ```
-- Get the source code
+
+- Create a *Conda* environment
 ```bash
-$ git clone https://github.com/peastman/NNPOps.git
+$ cd NNPOps
+$ conda env create -f pytorch/environment.yml
+$ conda activate nnpops
 ```
+
 - Configure, build, and install
 ```bash
 $ mkdir build
 $ cd build
-$ cmake ../NNPOps/pytorch \
+$ cmake ../pytorch \
        -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
        -DCMAKE_CUDA_HOST_COMPILER=$CXX \
-       -DTorch_DIR=$CONDA_PREFIX/lib/python3.8/site-packages/torch/share/cmake/Torch \
+       -DTorch_DIR=$CONDA_PREFIX/lib/python3.9/site-packages/torch/share/cmake/Torch \
        -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
 $ make install
 ```
-- Optional: run tests
+- Optional: run tests and benchmarks
 ```bash
-$ cd ../NNPOps/pytorch
+$ cd ../pytorch
 $ pytest TestSymmetryFunctions.py
+$ python BenchmarkTorchANISymmetryFunctions.py
 ```
\ No newline at end of file
diff --git a/pytorch/SymmetryFunctions.cpp b/pytorch/SymmetryFunctions.cpp
index 6b1d1f8..b4284cc 100644
--- a/pytorch/SymmetryFunctions.cpp
+++ b/pytorch/SymmetryFunctions.cpp
@@ -32,31 +32,52 @@
         throw std::runtime_error(std::string("Encountered error ")+cudaGetErrorName(result)+" at "+__FILE__+":"+std::to_string(__LINE__));\
     }
 
-class CustomANISymmetryFunctions : public torch::CustomClassHolder {
+namespace NNPOps {
+namespace ANISymmetryFunctions {
+
+class Holder;
+using std::vector;
+using HolderPtr = torch::intrusive_ptr<Holder>;
+using torch::Tensor;
+using torch::optional;
+using Context = torch::autograd::AutogradContext;
+using torch::autograd::tensor_list;
+
+class Holder : public torch::CustomClassHolder {
 public:
-    CustomANISymmetryFunctions(int64_t numSpecies_,
-                               double Rcr,
-                               double Rca,
-                               const std::vector<double>& EtaR,
-                               const std::vector<double>& ShfR,
-                               const std::vector<double>& EtaA,
-                               const std::vector<double>& Zeta,
-                               const std::vector<double>& ShfA,
-                               const std::vector<double>& ShfZ,
-                               const std::vector<int64_t>& atomSpecies_,
-                               const torch::Tensor& positions) : torch::CustomClassHolder() {
+
+    // Constructor for an uninitialized object
+    // Note: this is needed for serialization
+    Holder() : torch::CustomClassHolder() {};
+
+    Holder(int64_t numSpecies_,
+           double Rcr,
+           double Rca,
+           const vector<double>& EtaR,
+           const vector<double>& ShfR,
+           const vector<double>& EtaA,
+           const vector<double>& Zeta,
+           const vector<double>& ShfA,
+           const vector<double>& ShfZ,
+           const vector<int64_t>& atomSpecies_,
+           const Tensor& positions) : torch::CustomClassHolder() {
+
+        // Construct an uninitialized object
+        // Note: this is needed for Python bindings
+        if (numSpecies_ == 0)
+            return;
 
         tensorOptions = torch::TensorOptions().device(positions.device()); // Data type of float by default
         int numAtoms = atomSpecies_.size();
         int numSpecies = numSpecies_;
-        const std::vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());
+        const vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());
 
-        std::vector<RadialFunction> radialFunctions;
+        vector<RadialFunction> radialFunctions;
         for (const float eta: EtaR)
             for (const float rs: ShfR)
                 radialFunctions.push_back({eta, rs});
 
-        std::vector<AngularFunction> angularFunctions;
+        vector<AngularFunction> angularFunctions;
         for (const float eta: EtaA)
             for (const float zeta: Zeta)
                 for (const float rs: ShfA)
@@ -77,11 +98,11 @@ class CustomANISymmetryFunctions : public torch::CustomClassHolder {
         positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);
     };
 
-    torch::autograd::tensor_list forward(const torch::Tensor& positions_, const torch::optional<torch::Tensor>& periodicBoxVectors_) {
+    tensor_list forward(const Tensor& positions_, const optional<Tensor>& periodicBoxVectors_) {
 
-        const torch::Tensor positions = positions_.to(tensorOptions);
+        const Tensor positions = positions_.to(tensorOptions);
 
-        torch::Tensor periodicBoxVectors;
+        Tensor periodicBoxVectors;
         float* periodicBoxVectorsPtr = nullptr;
         if (periodicBoxVectors_) {
             periodicBoxVectors = periodicBoxVectors_->to(tensorOptions);
@@ -93,99 +114,86 @@ class CustomANISymmetryFunctions : public torch::CustomClassHolder {
         return {radial, angular};
     };
 
-    torch::Tensor backward(const torch::autograd::tensor_list& grads) {
+    Tensor backward(const tensor_list& grads) {
 
-        const torch::Tensor radialGrad = grads[0].clone();
-        const torch::Tensor angularGrad = grads[1].clone();
+        const Tensor radialGrad = grads[0].clone();
+        const Tensor angularGrad = grads[1].clone();
 
         symFunc->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());
 
         return positionsGrad;
-    }
+    };
+
+    bool is_initialized() {
+        return bool(symFunc);
+    };
 
 private:
     torch::TensorOptions tensorOptions;
-    std::shared_ptr<ANISymmetryFunctions> symFunc;
-    torch::Tensor radial;
-    torch::Tensor angular;
-    torch::Tensor positionsGrad;
+    std::shared_ptr<::ANISymmetryFunctions> symFunc;
+    Tensor radial;
+    Tensor angular;
+    Tensor positionsGrad;
 };
 
-class GradANISymmetryFunction : public torch::autograd::Function<GradANISymmetryFunction> {
+class AutogradFunctions : public torch::autograd::Function<AutogradFunctions> {
+
 public:
-    static torch::autograd::tensor_list forward(torch::autograd::AutogradContext *ctx,
-                                                int64_t numSpecies,
-                                                double Rcr,
-                                                double Rca,
-                                                const std::vector<double>& EtaR,
-                                                const std::vector<double>& ShfR,
-                                                const std::vector<double>& EtaA,
-                                                const std::vector<double>& Zeta,
-                                                const std::vector<double>& ShfA,
-                                                const std::vector<double>& ShfZ,
-                                                const std::vector<int64_t>& atomSpecies,
-                                                const torch::Tensor& positions,
-                                                const torch::optional<torch::Tensor>& periodicBoxVectors) {
-
-        const auto symFunc = torch::intrusive_ptr<CustomANISymmetryFunctions>::make(
-            numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions);
-        ctx->saved_data["symFunc"] = symFunc;
-
-        return symFunc->forward(positions, periodicBoxVectors);
+    static tensor_list forward(Context *ctx,
+                               const HolderPtr& holder,
+                               const Tensor& positions,
+                               const optional<Tensor>& periodicBoxVectors) {
+
+        ctx->saved_data["holder"] = holder;
+
+        return holder->forward(positions, periodicBoxVectors);
     };
 
-    static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, const torch::autograd::tensor_list& grads) {
-
-        const auto symFunc = ctx->saved_data["symFunc"].toCustomClass<CustomANISymmetryFunctions>();
-        torch::Tensor positionsGrad = symFunc->backward(grads);
-        ctx->saved_data.erase("symFunc");
-
-        return { torch::Tensor(),  // numSpecies
-                 torch::Tensor(),  // Rcr
-                 torch::Tensor(),  // Rca
-                 torch::Tensor(),  // EtaR
-                 torch::Tensor(),  // ShfR
-                 torch::Tensor(),  // EtaA
-                 torch::Tensor(),  // Zeta
-                 torch::Tensor(),  // ShfA
-                 torch::Tensor(),  // ShfZ
-                 torch::Tensor(),  // atomSpecies
-                 positionsGrad,    // positions
-                 torch::Tensor()}; // periodicBoxVectors
+    static tensor_list backward(Context *ctx, const tensor_list& grads) {
+
+        const auto holder = ctx->saved_data["holder"].toCustomClass<Holder>();
+        Tensor positionsGrad = holder->backward(grads);
+        ctx->saved_data.erase("holder");
+
+        return { Tensor(),      // holder
+                 positionsGrad, // positions
+                 Tensor() };    // periodicBoxVectors
     };
 };
 
-static torch::autograd::tensor_list ANISymmetryFunctionsOp(int64_t numSpecies,
-                                                           double Rcr,
-                                                           double Rca,
-                                                           const std::vector<double>& EtaR,
-                                                           const std::vector<double>& ShfR,
-                                                           const std::vector<double>& EtaA,
-                                                           const std::vector<double>& Zeta,
-                                                           const std::vector<double>& ShfA,
-                                                           const std::vector<double>& ShfZ,
-                                                           const std::vector<int64_t>& atomSpecies,
-                                                           const torch::Tensor& positions,
-                                                           const torch::optional<torch::Tensor>& periodicBoxVectors) {
-
-    return GradANISymmetryFunction::apply(numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions, periodicBoxVectors);
+tensor_list operation(const optional<HolderPtr>& holder,
+                      const Tensor& positions,
+                      const optional<Tensor>& periodicBoxVectors) {
+
+    return AutogradFunctions::apply(*holder, positions, periodicBoxVectors);
+}
+
+TORCH_LIBRARY(NNPOpsANISymmetryFunctions, m) {
+    m.class_<Holder>("Holder")
+        .def(torch::init<int64_t,                // numSpecies
+                         double,                 // Rcr
+                         double,                 // Rca
+                         const vector<double>&,  // EtaR
+                         const vector<double>&,  // ShfR
+                         const vector<double>&,  // EtaA
+                         const vector<double>&,  // Zeta
+                         const vector<double>&,  // ShfA
+                         const vector<double>&,  // ShfZ
+                         const vector<int64_t>&, // atomSpecies
+                         const Tensor&>())       // positions
+        .def("forward", &Holder::forward)
+        .def("backward", &Holder::backward)
+        .def("is_initialized", &Holder::is_initialized)
+        .def_pickle(
+            // __getstate__
+            // Note: nothing is stored during serialization
+            [](const HolderPtr& self) -> int64_t { return 0; },
+            // __setstate__
+            // Note: a new uninitialized object is created during deserialization
+            [](int64_t state) -> HolderPtr { return HolderPtr::make(); }
+        );
+    m.def("operation", operation);
 }
 
-TORCH_LIBRARY(NNPOps, m) {
-    m.class_<CustomANISymmetryFunctions>("CustomANISymmetryFunctions")
-        .def(torch::init<int64_t,                     // numSpecies
-                         double,                      // Rcr
-                         double,                      // Rca
-                         const std::vector<double>&,  // EtaR
-                         const std::vector<double>&,  // ShfR
-                         const std::vector<double>&,  // EtaA
-                         const std::vector<double>&,  // Zeta
-                         const std::vector<double>&,  // ShfA
-                         const std::vector<double>&,  // ShfZ
-                         const std::vector<int64_t>&, // atomSpecies
-                         const torch::Tensor&>())     // positions
-        .def("forward", &CustomANISymmetryFunctions::forward)
-        .def("backward", &CustomANISymmetryFunctions::backward);
-
m.def("ANISymmetryFunctions", ANISymmetryFunctionsOp); -} \ No newline at end of file +} // namespace ANISymmetryFunctions +} // namespace NNPOps \ No newline at end of file diff --git a/pytorch/SymmetryFunctions.py b/pytorch/SymmetryFunctions.py index 224b793..af3bb84 100644 --- a/pytorch/SymmetryFunctions.py +++ b/pytorch/SymmetryFunctions.py @@ -29,6 +29,10 @@ from torchani.aev import SpeciesAEV torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so')) +torch.classes.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so')) + +Holder = torch.classes.NNPOpsANISymmetryFunctions.Holder +operation = torch.ops.NNPOpsANISymmetryFunctions.operation class TorchANISymmetryFunctions(torch.nn.Module): """Optimized TorchANI symmetry functions @@ -66,7 +70,6 @@ def __init__(self, symmFunc: torchani.AEVComputer): Arguments: symmFunc: the instance of torchani.AEVComputer (https://aiqm.github.io/torchani/api.html#torchani.AEVComputer) """ - super().__init__() self.numSpecies = symmFunc.num_species @@ -79,6 +82,10 @@ def __init__(self, symmFunc: torchani.AEVComputer): self.ShfA = symmFunc.ShfA[0, 0, :, 0].tolist() self.ShfZ = symmFunc.ShfZ[0, 0, 0, :].tolist() + # Create an uninitialized holder + self.holder = Holder(0, 0, 0, [], [] , [] , [], [] , [], [], Tensor()) + assert not self.holder.is_initialized() + self.triu_index = torch.tensor([0]) # A dummy variable to make TorchScript happy ;) def forward(self, speciesAndPositions: Tuple[Tensor, Tensor], @@ -100,7 +107,6 @@ def forward(self, speciesAndPositions: Tuple[Tensor, Tensor], species, positions = speciesAndPositions if species.shape[0] != 1: raise ValueError('Batched molecule computation is not supported') - species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript if species.shape + (3,) != positions.shape: raise ValueError('Inconsistent shapes of "species" and "positions"') if cell is not None: @@ -113,10 +119,15 @@ def forward(self, speciesAndPositions: Tuple[Tensor, Tensor], if pbc_ != [True, True, True]: raise ValueError('Only fully periodic systems are supported, i.e. pbc = [True, True, True]') - symFunc = torch.ops.NNPOps.ANISymmetryFunctions - radial, angular = symFunc(self.numSpecies, self.Rcr, self.Rca, self.EtaR, self.ShfR, - self.EtaA, self.Zeta, self.ShfA, self.ShfZ, - species_, positions[0], cell) + if not self.holder.is_initialized(): + species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript + self.holder = Holder(self.numSpecies, self.Rcr, self.Rca, + self.EtaR, self.ShfR, + self.EtaA, self.Zeta, self.ShfA, self.ShfZ, + species_, positions) + assert self.holder.is_initialized() + + radial, angular = operation(self.holder, positions[0], cell) features = torch.cat((radial, angular), dim=1).unsqueeze(0) return SpeciesAEV(species, features) \ No newline at end of file diff --git a/pytorch/environment.yml b/pytorch/environment.yml new file mode 100644 index 0000000..2cb0894 --- /dev/null +++ b/pytorch/environment.yml @@ -0,0 +1,12 @@ +name: nnpops +channels: + - conda-forge +dependencies: + - cmake + - gxx_linux-64 + - make + - mdtraj + - torchani 2.2 + - pytest + - python 3.9 + - pytorch 1.8.0 \ No newline at end of file