diff --git a/pytorch/SymmetryFunctions.cpp b/pytorch/SymmetryFunctions.cpp index 285ed05..372d338 100644 --- a/pytorch/SymmetryFunctions.cpp +++ b/pytorch/SymmetryFunctions.cpp @@ -24,6 +24,71 @@ #include #include "CpuANISymmetryFunctions.h" +class GradANISymmetryFunction : public torch::autograd::Function { + +public: + static torch::autograd::tensor_list forward(torch::autograd::AutogradContext *ctx, + int64_t numSpecies, + double Rcr, + double Rca, + const std::vector& EtaR, + const std::vector& ShfR, + const std::vector& EtaA, + const std::vector& Zeta, + const std::vector& ShfA, + const std::vector& ShfZ, + const std::vector& atomSpecies_, + const torch::Tensor& positions_) { + + const int numAtoms = atomSpecies_.size(); + const std::vector atomSpecies(atomSpecies_.begin(), atomSpecies_.end()); + + std::vector radialFunctions; + for (const float eta: EtaR) + for (const float rs: ShfR) + radialFunctions.push_back({eta, rs}); + + std::vector angularFunctions; + for (const float eta: EtaA) + for (const float zeta: Zeta) + for (const float rs: ShfA) + for (const float thetas: ShfZ) + angularFunctions.push_back({eta, rs, zeta, thetas}); + + CpuANISymmetryFunctions sf(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true); + + const auto positions = positions_.toType(torch::kFloat); + auto radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, torch::kFloat); + auto angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, torch::kFloat); + + sf.computeSymmetryFunctions(positions.data_ptr(), nullptr, radial.data_ptr(), angular.data_ptr()); + + return {radial, angular}; + }; + + static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, torch::autograd::tensor_list grads) { + + const auto& radialGrad = grads[0]; + const auto& angularGrad = grads[1]; + + // compute the gradients + + torch::Tensor positionsGrad = torch::Tensor(); + + return { torch::Tensor(), // numSpecies + torch::Tensor(), // Rcr + torch::Tensor(), // Rca + torch::Tensor(), // EtaR + torch::Tensor(), // ShfR + torch::Tensor(), // EtaA + torch::Tensor(), // Zeta + torch::Tensor(), // ShfA + torch::Tensor(), // ShfZ + torch::Tensor(), // atomSpecies + positionsGrad }; // positions + }; +}; + static torch::autograd::tensor_list ANISymmetryFunction(int64_t numSpecies, double Rcr, double Rca, @@ -33,35 +98,11 @@ static torch::autograd::tensor_list ANISymmetryFunction(int64_t numSpecies, const std::vector& Zeta, const std::vector& ShfA, const std::vector& ShfZ, - const std::vector& atomSpecies_, - const torch::Tensor& positions_) { - - const int numAtoms = atomSpecies_.size(); - const std::vector atomSpecies(atomSpecies_.begin(), atomSpecies_.end()); - - std::vector radialFunctions; - for (const float eta: EtaR) - for (const float rs: ShfR) - radialFunctions.push_back({eta, rs}); - - std::vector angularFunctions; - for (const float eta: EtaA) - for (const float zeta: Zeta) - for (const float rs: ShfA) - for (const float thetas: ShfZ) - angularFunctions.push_back({eta, rs, zeta, thetas}); - - CpuANISymmetryFunctions sf(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true); - - const auto positions = positions_.toType(torch::kFloat); - auto radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, torch::kFloat); - auto angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, torch::kFloat); - - sf.computeSymmetryFunctions(positions.data_ptr(), nullptr, radial.data_ptr(), angular.data_ptr()); - - return {radial, angular}; + const std::vector& atomSpecies, + const torch::Tensor& positions) { + return GradANISymmetryFunction::apply(numSpecies, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, atomSpecies, positions); } TORCH_LIBRARY(NNPOps, m) { m.def("ANISymmetryFunction", ANISymmetryFunction); -} +} \ No newline at end of file