Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch wrapper #5

Merged
merged 33 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b74a2b8
PyTorch wrapper for the forward pass on CPU
Sep 25, 2020
a8eb8e0
CMake file for the PyTorch wrapper
Sep 25, 2020
bae1ecb
Pytorch wrapper for the backward pass (not yet working)
Sep 28, 2020
21b88b8
Wrap CpuANISymmetryFunctions as a custom Pytorch class
Sep 28, 2020
e7cf48c
Pytorch wrapper of the backward pass
Sep 28, 2020
3fdc9f0
Simplify Pytorch wrapper
Sep 29, 2020
7f65603
Pytorch wrapper for the CUDA implementation
Sep 29, 2020
a9a1fbd
Fix a typo
Sep 29, 2020
45a6031
Simplfy the Pytorch wrapper
Sep 30, 2020
46ef01e
Fix the memory leak in the PyTorch wrapper
Oct 1, 2020
62206eb
Pass the box vector to the PyTorch wrapper
Oct 1, 2020
0ea2863
Merge branch 'master' into pytorch
Oct 5, 2020
d94936a
Unify the names of PyTorch wrapper
Oct 6, 2020
524982a
Implement integration with TorchANI via the PyTorch wrapper
Oct 6, 2020
b2b2a9e
Simplify and add check to the TorchANI integration
Oct 6, 2020
4e6f226
Merge branch 'master' into pytorch
Oct 7, 2020
b3c6ca4
Rename the PyTorch wrapper component
Oct 7, 2020
fec2500
Add a test for TorchANISymmetryFunctions
Oct 7, 2020
060701b
Fix the serialization of TorchANISymmetryFunctions
Oct 7, 2020
5cd9ab9
Add a test for the serialization of TorchANISymmetryFunctions
Oct 7, 2020
f8b4584
Merge remote-tracking branch 'origin/master' into pytorch
Oct 15, 2020
551591b
Add more molecules for TorchANISymmetryFunctions tests
Oct 21, 2020
bdae88f
Update TorchANISymmetryFunctions tests to use all the molecules
Oct 21, 2020
ee82560
Improve CMake file for NNPOpsPyTorch
Oct 21, 2020
40bb2fa
Add installation instructions for NNPOpsPyTorch
Oct 21, 2020
df3b1ee
Fix the import of NNPOps in Python
Oct 21, 2020
525153f
Add an usage example for NNPOpsPyTorch
Oct 21, 2020
4d9fdc9
Fix the import in the example
Oct 21, 2020
ae73034
Add docstrings for TorchANISymmetryFunctions
Oct 22, 2020
c5cf004
Add more general text about the wrapper
Oct 22, 2020
cd86134
Fix typo
Oct 23, 2020
ddd222c
Add a benchmark script for TorchANISymmetryFunctions
Oct 28, 2020
ed3cc15
Make PyTorch and NNPOps to run on the same GPU
Oct 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions pytorch/BenchmarkTorchANISymmetryFunctions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import mdtraj
import time
import torch
import torchani

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device('cuda')

mol = mdtraj.load('molecules/2iuz_ligand.mol2')
species = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
positions = torch.tensor(mol.xyz, dtype=torch.float32, requires_grad=True, device=device)

nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
speciesPositions = nnp.species_converter((species, positions))
symmFuncRef = nnp.aev_computer
symmFunc = TorchANISymmetryFunctions(nnp.aev_computer).to(device)

aev_ref = symmFuncRef(speciesPositions).aevs
sum_aev_ref = torch.sum(aev_ref)
sum_aev_ref.backward()
grad_ref = positions.grad.clone()

N = 10000
start = time.time()
for _ in range(N):
aev_ref = symmFuncRef(speciesPositions).aevs
sum_aev_ref = torch.sum(aev_ref)
positions.grad.zero_()
sum_aev_ref.backward()
delta = time.time() - start
grad_ref = positions.grad.clone()
print('Original TorchANI symmetry functions')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

aev = symmFunc(speciesPositions).aevs
sum_aev = torch.sum(aev)
positions.grad.zero_()
sum_aev.backward()
grad = positions.grad.clone()

N = 40000
start = time.time()
for _ in range(N):
aev = symmFunc(speciesPositions).aevs
sum_aev = torch.sum(aev)
positions.grad.zero_()
sum_aev.backward()
delta = time.time() - start
grad = positions.grad.clone()
print('Optimized TorchANI symmetry functions')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

aev_error = torch.max(torch.abs(aev - aev_ref))
grad_error = torch.max(torch.abs(grad - grad_ref))
print(aev_error)
print(grad_error)
assert aev_error < 0.0002
assert grad_error < 0.007
22 changes: 22 additions & 0 deletions pytorch/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)

set(NAME NNPOps)
set(LIBRARY ${NAME}PyTorch)
project(${NAME} LANGUAGES CXX CUDA)

find_package(Python REQUIRED)
find_package(PythonLibs REQUIRED)
find_package(Torch REQUIRED)

set(CMAKE_INSTALL_RPATH_USE_LINK_PATH true)

add_library(${LIBRARY} SHARED SymmetryFunctions.cpp
../ani/CpuANISymmetryFunctions.cpp
../ani/CudaANISymmetryFunctions.cu)
target_compile_features(${LIBRARY} PRIVATE cxx_std_14)
target_include_directories(${LIBRARY} PRIVATE ${PYTHON_INCLUDE_DIRS})
target_include_directories(${LIBRARY} PRIVATE ../ani)
target_link_libraries(${LIBRARY} ${TORCH_LIBRARIES} ${PYTHON_LIBRARIES})

install(TARGETS ${LIBRARY} DESTINATION ${Python_SITEARCH}/${NAME})
install(FILES SymmetryFunctions.py DESTINATION ${Python_SITEARCH}/${NAME})
82 changes: 82 additions & 0 deletions pytorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# PyTorch wrapper for NNPOps

*NNPOps* functionalities are available in *PyTorch* (https://pytorch.org/).

## Optimized TorchANI symmetry functions

Optimized drop-in replacement for `torchani.AEVComputer` (https://aiqm.github.io/torchani/api.html?highlight=speciesaev#torchani.AEVComputer)

### Example

```python
import mdtraj
import torch
import torchani

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device('cuda')

# Load a molecule
molecule = mdtraj.load('molecule.mol2')
species = torch.tensor([[atom.element.atomic_number for atom in molecule.top.atoms]], device=device)
positions = torch.tensor(molecule.xyz, dtype=torch.float32, requires_grad=True, device=device)

# Construct ANI-2x and replace its native featurizer with NNPOps implementation
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer)

# Compute energy
energy = nnp((species, positions)).energies
energy.backward()
forces = -positions.grad.clone()

print(energy, forces)
```

## Installation

### Prerequisites

- *Linux*
- Complete *CUDA Toolkit* (https://developer.nvidia.com/cuda-downloads)
- *Miniconda* (https://docs.conda.io/en/latest/miniconda.html#linux-installers)

### Build & install

- Crate a *Conda* environment
```bash
$ conda create -n nnpops \
-c pytorch \
-c conda-forge \
cmake \
git \
gxx_linux-64 \
make \
mdtraj \
pytest \
python=3.8 \
pytorch=1.6 \
torchani=2.2
$ conda activate nnpops
```
- Get the source code
```bash
$ git clone https://github.com/peastman/NNPOps.git
```
- Configure, build, and install
```bash
$ mkdir build
$ cd build
$ cmake ../NNPOps/pytorch \
-DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
-DCMAKE_CUDA_HOST_COMPILER=$CXX \
-DTorch_DIR=$CONDA_PREFIX/lib/python3.8/site-packages/torch/share/cmake/Torch \
-DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
$ make install
```
- Optional: run tests
```bash
$ cd ../NNPOps/pytorch
$ pytest TestSymmetryFunctions.py
```
191 changes: 191 additions & 0 deletions pytorch/SymmetryFunctions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/**
* Copyright (c) 2020 Acellera
* Authors: Raimondas Galvelis
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

#include <stdexcept>
#include <cuda_runtime.h>
#include <torch/script.h>
#include "CpuANISymmetryFunctions.h"
#include "CudaANISymmetryFunctions.h"

#define CHECK_CUDA_RESULT(result) \
if (result != cudaSuccess) { \
throw std::runtime_error(std::string("Encountered error ")+cudaGetErrorName(result)+" at "+__FILE__+":"+std::to_string(__LINE__));\
}

class CustomANISymmetryFunctions : public torch::CustomClassHolder {
public:
CustomANISymmetryFunctions(int64_t numSpecies_,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies_,
const torch::Tensor& positions) : torch::CustomClassHolder() {

tensorOptions = torch::TensorOptions().device(positions.device()); // Data type of float by default
int numAtoms = atomSpecies_.size();
int numSpecies = numSpecies_;
const std::vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());

std::vector<RadialFunction> radialFunctions;
for (const float eta: EtaR)
for (const float rs: ShfR)
radialFunctions.push_back({eta, rs});

std::vector<AngularFunction> angularFunctions;
for (const float eta: EtaA)
for (const float zeta: Zeta)
for (const float rs: ShfA)
for (const float thetas: ShfZ)
angularFunctions.push_back({eta, rs, zeta, thetas});

const torch::Device& device = tensorOptions.device();
if (device.is_cpu())
symFunc = std::make_shared<CpuANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);
if (device.is_cuda()) {
// PyTorch allow to chose GPU with "torch.device", but it doesn't set as the default one.
CHECK_CUDA_RESULT(cudaSetDevice(device.index()));
symFunc = std::make_shared<CudaANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);
}

radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, tensorOptions);
angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, tensorOptions);
positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);
};

torch::autograd::tensor_list forward(const torch::Tensor& positions_, const torch::optional<torch::Tensor>& periodicBoxVectors_) {

const torch::Tensor positions = positions_.to(tensorOptions);

torch::Tensor periodicBoxVectors;
float* periodicBoxVectorsPtr = nullptr;
if (periodicBoxVectors_) {
periodicBoxVectors = periodicBoxVectors_->to(tensorOptions);
float* periodicBoxVectorsPtr = periodicBoxVectors.data_ptr<float>();
}

symFunc->computeSymmetryFunctions(positions.data_ptr<float>(), periodicBoxVectorsPtr, radial.data_ptr<float>(), angular.data_ptr<float>());

return {radial, angular};
};

torch::Tensor backward(const torch::autograd::tensor_list& grads) {

const torch::Tensor radialGrad = grads[0].clone();
const torch::Tensor angularGrad = grads[1].clone();

symFunc->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());

return positionsGrad;
}

private:
torch::TensorOptions tensorOptions;
std::shared_ptr<ANISymmetryFunctions> symFunc;
torch::Tensor radial;
torch::Tensor angular;
torch::Tensor positionsGrad;
};

class GradANISymmetryFunction : public torch::autograd::Function<GradANISymmetryFunction> {

public:
static torch::autograd::tensor_list forward(torch::autograd::AutogradContext *ctx,
int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

const auto symFunc = torch::intrusive_ptr<CustomANISymmetryFunctions>::make(
numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions);
ctx->saved_data["symFunc"] = symFunc;

return symFunc->forward(positions, periodicBoxVectors);
};

static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, const torch::autograd::tensor_list& grads) {

const auto symFunc = ctx->saved_data["symFunc"].toCustomClass<CustomANISymmetryFunctions>();
torch::Tensor positionsGrad = symFunc->backward(grads);
ctx->saved_data.erase("symFunc");

return { torch::Tensor(), // numSpecies
torch::Tensor(), // Rcr
torch::Tensor(), // Rca
torch::Tensor(), // EtaR
torch::Tensor(), // ShfR
torch::Tensor(), // EtaA
torch::Tensor(), // Zeta
torch::Tensor(), // ShfA
torch::Tensor(), // ShfZ
torch::Tensor(), // atomSpecies
positionsGrad, // positions
torch::Tensor()}; // periodicBoxVectors
};
};

static torch::autograd::tensor_list ANISymmetryFunctionsOp(int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions,
const torch::optional<torch::Tensor>& periodicBoxVectors) {

return GradANISymmetryFunction::apply(numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions, periodicBoxVectors);
}

TORCH_LIBRARY(NNPOps, m) {
m.class_<CustomANISymmetryFunctions>("CustomANISymmetryFunctions")
.def(torch::init<int64_t, // numSpecies
double, // Rcr
double, // Rca
const std::vector<double>&, // EtaR
const std::vector<double>&, // ShfR
const std::vector<double>&, // EtaA
const std::vector<double>&, // Zeta
const std::vector<double>&, // ShfA
const std::vector<double>&, // ShfZ
const std::vector<int64_t>&, // atomSpecies
const torch::Tensor&>()) // positions
.def("forward", &CustomANISymmetryFunctions::forward)
.def("backward", &CustomANISymmetryFunctions::backward);
m.def("ANISymmetryFunctions", ANISymmetryFunctionsOp);
}
Loading