Skip to content

Commit

Permalink
Pytorch wrapper for the backward pass (not yet working)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raimondas Galvelis committed Sep 28, 2020
1 parent d13eaff commit df6f3bc
Showing 1 changed file with 69 additions and 28 deletions.
97 changes: 69 additions & 28 deletions pytorch/SymmetryFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,71 @@
#include <torch/script.h>
#include "CpuANISymmetryFunctions.h"

class GradANISymmetryFunction : public torch::autograd::Function<GradANISymmetryFunction> {

public:
static torch::autograd::tensor_list forward(torch::autograd::AutogradContext *ctx,
int64_t numSpecies,
double Rcr,
double Rca,
const std::vector<double>& EtaR,
const std::vector<double>& ShfR,
const std::vector<double>& EtaA,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies_,
const torch::Tensor& positions_) {

const int numAtoms = atomSpecies_.size();
const std::vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());

std::vector<RadialFunction> radialFunctions;
for (const float eta: EtaR)
for (const float rs: ShfR)
radialFunctions.push_back({eta, rs});

std::vector<AngularFunction> angularFunctions;
for (const float eta: EtaA)
for (const float zeta: Zeta)
for (const float rs: ShfA)
for (const float thetas: ShfZ)
angularFunctions.push_back({eta, rs, zeta, thetas});

CpuANISymmetryFunctions sf(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);

const auto positions = positions_.toType(torch::kFloat);
auto radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, torch::kFloat);
auto angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, torch::kFloat);

sf.computeSymmetryFunctions(positions.data_ptr<float>(), nullptr, radial.data_ptr<float>(), angular.data_ptr<float>());

return {radial, angular};
};

static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, torch::autograd::tensor_list grads) {

const auto& radialGrad = grads[0];
const auto& angularGrad = grads[1];

// compute the gradients

torch::Tensor positionsGrad = torch::Tensor();

return { torch::Tensor(), // numSpecies
torch::Tensor(), // Rcr
torch::Tensor(), // Rca
torch::Tensor(), // EtaR
torch::Tensor(), // ShfR
torch::Tensor(), // EtaA
torch::Tensor(), // Zeta
torch::Tensor(), // ShfA
torch::Tensor(), // ShfZ
torch::Tensor(), // atomSpecies
positionsGrad }; // positions
};
};

static torch::autograd::tensor_list ANISymmetryFunction(int64_t numSpecies,
double Rcr,
double Rca,
Expand All @@ -33,35 +98,11 @@ static torch::autograd::tensor_list ANISymmetryFunction(int64_t numSpecies,
const std::vector<double>& Zeta,
const std::vector<double>& ShfA,
const std::vector<double>& ShfZ,
const std::vector<int64_t>& atomSpecies_,
const torch::Tensor& positions_) {

const int numAtoms = atomSpecies_.size();
const std::vector<int> atomSpecies(atomSpecies_.begin(), atomSpecies_.end());

std::vector<RadialFunction> radialFunctions;
for (const float eta: EtaR)
for (const float rs: ShfR)
radialFunctions.push_back({eta, rs});

std::vector<AngularFunction> angularFunctions;
for (const float eta: EtaA)
for (const float zeta: Zeta)
for (const float rs: ShfA)
for (const float thetas: ShfZ)
angularFunctions.push_back({eta, rs, zeta, thetas});

CpuANISymmetryFunctions sf(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);

const auto positions = positions_.toType(torch::kFloat);
auto radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, torch::kFloat);
auto angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, torch::kFloat);

sf.computeSymmetryFunctions(positions.data_ptr<float>(), nullptr, radial.data_ptr<float>(), angular.data_ptr<float>());

return {radial, angular};
const std::vector<int64_t>& atomSpecies,
const torch::Tensor& positions) {
return GradANISymmetryFunction::apply(numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions);
}

TORCH_LIBRARY(NNPOps, m) {
m.def("ANISymmetryFunction", ANISymmetryFunction);
}
}

0 comments on commit df6f3bc

Please sign in to comment.