forked from openmm/NNPOps
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Raimondas Galvelis
committed
Jun 7, 2022
1 parent
262261e
commit 1a44054
Showing
4 changed files
with
188 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import os
import torch as pt
from torch.utils import cpp_extension

# Collect the extension sources living next to this file; the CUDA kernel
# is only compiled when a CUDA device is actually available.
src_dir = os.path.dirname(__file__)
sources = ['messages.cpp', 'messages_cpu.cpp']
if pt.cuda.is_available():
    sources.append('messages_cuda.cu')
sources = [os.path.join(src_dir, name) for name in sources]

# JIT-compile and register the op library with torch (not as a Python
# module), then expose the registered operator under a plain Python name.
cpp_extension.load(name='messages', sources=sources, is_python_module=False)
pass_messages = pt.ops.messages.pass_messages
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#include <torch/extension.h> | ||
|
||
// Declare the "messages" operator library and the schema of its single op.
// pass_messages(neighbors, messages, states) -> Tensor; concrete kernels are
// registered separately per backend (see the CPU and CUDA source files).
// NOTE(review): the schema names the returned tensor "messages", but both
// backend implementations return updated states — consider renaming.
TORCH_LIBRARY(messages, m) {
    m.def("pass_messages(Tensor neighbors, Tensor messages, Tensor states) -> (Tensor messages)");
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#include <torch/extension.h> | ||
|
||
using torch::kInt32; | ||
using torch::logical_and; | ||
using torch::Tensor; | ||
|
||
// CPU implementation of pass_messages: one message-passing accumulation step.
//
// neighbors: int32 tensor of shape (2, num_neighbors); neighbors[0][i] and
//            neighbors[1][i] are the two endpoint indices of pair i. Negative
//            indices mark padding entries and are skipped.
// messages:  2-D tensor, assumed (num_neighbors, num_features) — the pair
//            count is not validated against "neighbors" here.
// states:    2-D tensor with the same feature width and dtype as "messages".
//
// Returns a clone of "states" in which, for every valid pair, the pair's
// message row has been added to the state rows of BOTH endpoints.
static Tensor forward(const Tensor& neighbors, const Tensor& messages, const Tensor& states) {

    TORCH_CHECK(neighbors.dim() == 2, "Expected \"neighbors\" to have two dimensions");
    // size(0) is the first dimension — the original message said "2nd".
    TORCH_CHECK(neighbors.size(0) == 2, "Expected the 1st dimension size of \"neighbors\" to be 2");
    TORCH_CHECK(neighbors.scalar_type() == kInt32, "Expected \"neighbors\" to have data type of int32");
    TORCH_CHECK(neighbors.is_contiguous(), "Expected \"neighbors\" to be contiguous");

    TORCH_CHECK(messages.dim() == 2, "Expected \"messages\" to have two dimensions");
    TORCH_CHECK(messages.size(1) % 32 == 0, "Expected the 2nd dimension size of \"messages\" to be a multiple of 32");
    // The check is inclusive (<=), so the message must say so too.
    TORCH_CHECK(messages.size(1) <= 1024, "Expected the 2nd dimension size of \"messages\" to be less than or equal to 1024");
    TORCH_CHECK(messages.is_contiguous(), "Expected \"messages\" to be contiguous");

    TORCH_CHECK(states.dim() == 2, "Expected \"states\" to have two dimensions");
    TORCH_CHECK(states.size(1) == messages.size(1), "Expected the 2nd dimension size of \"messages\" and \"states\" to be the same");
    TORCH_CHECK(states.scalar_type() == messages.scalar_type(), "Expected the data type of \"messages\" and \"states\" to be the same");
    // This check is about "states"; the original message named "messages".
    TORCH_CHECK(states.is_contiguous(), "Expected \"states\" to be contiguous");

    // The two rows of "neighbors" are the two endpoints of each pair.
    const Tensor rows = neighbors[0];
    const Tensor columns = neighbors[1];

    const int num_features = messages.size(1);

    // Keep only pairs where both endpoints are valid (non-negative).
    const Tensor mask = logical_and(rows > -1, columns > -1);
    const Tensor masked_rows = rows.masked_select(mask).to(torch::kLong);
    const Tensor masked_columns = columns.masked_select(mask).to(torch::kLong);
    const Tensor masked_messages = messages.masked_select(mask.unsqueeze(1)).reshape({-1, num_features});

    // Accumulate each surviving message into the states of both endpoints.
    Tensor new_states = states.clone();
    new_states.index_add_(0, masked_rows, masked_messages);
    new_states.index_add_(0, masked_columns, masked_messages);

    return new_states;
}
|
||
// Register the function above as the CPU kernel for messages::pass_messages.
TORCH_LIBRARY_IMPL(messages, CPU, m) {
    m.impl("pass_messages", &forward);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
#include <c10/cuda/CUDAGuard.h> | ||
#include <c10/cuda/CUDAStream.h> | ||
#include <torch/extension.h> | ||
|
||
using c10::cuda::CUDAStreamGuard; | ||
using c10::cuda::getCurrentCUDAStream; | ||
using torch::autograd::AutogradContext; | ||
using torch::autograd::Function; | ||
using torch::autograd::tensor_list; | ||
using torch::kInt32; | ||
using torch::PackedTensorAccessor32; | ||
using torch::RestrictPtrTraits; | ||
using torch::Tensor; | ||
using torch::TensorOptions; | ||
|
||
// Shorthand for a 32-bit-indexed packed tensor accessor with
// __restrict__-qualified pointers, as passed to the kernels below.
template <typename scalar_t, int num_dims>
using Accessor = PackedTensorAccessor32<scalar_t, num_dims, RestrictPtrTraits>;

// Build an Accessor over "tensor" for kernel arguments.
template <typename scalar_t, int num_dims>
inline Accessor<scalar_t, num_dims> get_accessor(const Tensor& tensor) {
    return tensor.packed_accessor32<scalar_t, num_dims, RestrictPtrTraits>();
} // no trailing ';' — the original's stray semicolon after the body removed
|
||
// Forward kernel: one CUDA block per (neighbor pair, direction) — gridDim is
// (num_neighbors, 2) — and one thread per feature. Each block adds the pair's
// message vector into the state row of the atom at neighbors[i_dir][i_neig].
// atomicAdd is needed because several pairs may target the same atom row
// concurrently.
template <typename scalar_t> __global__ void kernel_forward(
    const Accessor<int32_t, 2> neighbors,
    const Accessor<scalar_t, 2> messages,
    Accessor<scalar_t, 2> new_states
) {
    const int32_t i_neig = blockIdx.x; // neighbor-pair index
    const int32_t i_dir = blockIdx.y;  // which endpoint of the pair (0 or 1)
    const int32_t i_atom = neighbors[i_dir][i_neig];
    if (i_atom < 0) return; // negative index marks a padding entry — skip

    const int32_t i_feat = threadIdx.x; // feature index within the row
    atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
}
|
||
// Backward kernel: same launch geometry as kernel_forward. The gradient of a
// message is the sum of the incoming state gradients of its two endpoint
// atoms, so each (pair, direction) block accumulates grad_new_state of its
// atom into grad_messages for that pair. Padding entries (negative atom
// index) contribute nothing.
template <typename scalar_t> __global__ void kernel_backward(
    const Accessor<int32_t, 2> neighbors,
    const Accessor<scalar_t, 2> grad_new_state,
    Accessor<scalar_t, 2> grad_messages
) {
    const int32_t i_neig = blockIdx.x; // neighbor-pair index
    const int32_t i_dir = blockIdx.y;  // which endpoint of the pair (0 or 1)
    const int32_t i_atom = neighbors[i_dir][i_neig];
    if (i_atom < 0) return; // padding entry — skip

    const int32_t i_feat = threadIdx.x; // feature index within the row
    atomicAdd(&grad_messages[i_neig][i_feat], grad_new_state[i_atom][i_feat]);
}
|
||
// Autograd-aware CUDA implementation of messages::pass_messages.
// forward() validates inputs, launches kernel_forward, and saves "neighbors"
// for the backward pass; backward() launches kernel_backward to produce the
// gradient w.r.t. "messages" and passes the incoming gradient through as the
// gradient w.r.t. "states" (new_states = states + contributions, so d/dstates
// is the identity). "neighbors" is an integer input and gets no gradient.
class Autograd : public Function<Autograd> {
public:
    static tensor_list forward(AutogradContext* ctx,
                               const Tensor& neighbors,
                               const Tensor& messages,
                               const Tensor& states) {

        TORCH_CHECK(neighbors.dim() == 2, "Expected \"neighbors\" to have two dimensions");
        // size(0) is the first dimension — the original message said "2nd".
        TORCH_CHECK(neighbors.size(0) == 2, "Expected the 1st dimension size of \"neighbors\" to be 2");
        TORCH_CHECK(neighbors.scalar_type() == kInt32, "Expected \"neighbors\" to have data type of int32");
        TORCH_CHECK(neighbors.is_contiguous(), "Expected \"neighbors\" to be contiguous");

        TORCH_CHECK(messages.dim() == 2, "Expected \"messages\" to have two dimensions");
        TORCH_CHECK(messages.size(1) % 32 == 0, "Expected the 2nd dimension size of \"messages\" to be a multiple of 32");
        // The check is inclusive (<=), so the message must say so too.
        TORCH_CHECK(messages.size(1) <= 1024, "Expected the 2nd dimension size of \"messages\" to be less than or equal to 1024");
        TORCH_CHECK(messages.is_contiguous(), "Expected \"messages\" to be contiguous");

        TORCH_CHECK(states.dim() == 2, "Expected \"states\" to have two dimensions");
        TORCH_CHECK(states.size(1) == messages.size(1), "Expected the 2nd dimension size of \"messages\" and \"states\" to be the same");
        TORCH_CHECK(states.scalar_type() == messages.scalar_type(), "Expected the data type of \"messages\" and \"states\" to be the same");
        // This check is about "states"; the original message named "messages".
        TORCH_CHECK(states.is_contiguous(), "Expected \"states\" to be contiguous");

        const int num_neighbors = neighbors.size(1);
        const int num_features = messages.size(1);

        // One block per (pair, direction); one thread per feature. The
        // feature-count checks above keep threads within the 1024 limit.
        const dim3 blocks(num_neighbors, 2);
        const dim3 threads(num_features);
        const auto stream = getCurrentCUDAStream(neighbors.get_device());

        Tensor new_states = states.clone();

        AT_DISPATCH_FLOATING_TYPES(messages.scalar_type(), "pass_messages_forward", [&]() {
            const CUDAStreamGuard guard(stream);
            kernel_forward<<<blocks, threads, 0, stream>>>(
                get_accessor<int32_t, 2>(neighbors),
                get_accessor<scalar_t, 2>(messages),
                get_accessor<scalar_t, 2>(new_states));
        });

        // Only "neighbors" is needed to route gradients in backward().
        ctx->save_for_backward({neighbors});

        return {new_states};
    }

    static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) {

        const Tensor neighbors = ctx->get_saved_variables()[0];
        const Tensor grad_new_state = grad_inputs[0];

        const int num_neighbors = neighbors.size(1);
        const int num_features = grad_new_state.size(1);

        // Same launch geometry as the forward pass.
        const dim3 blocks(num_neighbors, 2);
        const dim3 threads(num_features);
        const auto stream = getCurrentCUDAStream(neighbors.get_device());

        Tensor grad_messages = torch::zeros({num_neighbors, num_features}, grad_new_state.options());

        AT_DISPATCH_FLOATING_TYPES(grad_new_state.scalar_type(), "pass_messages_backward", [&]() {
            const CUDAStreamGuard guard(stream);
            kernel_backward<<<blocks, threads, 0, stream>>>(
                get_accessor<int32_t, 2>(neighbors),
                get_accessor<scalar_t, 2>(grad_new_state),
                get_accessor<scalar_t, 2>(grad_messages));
        });

        return {Tensor(),                 // grad_neighbors (integer input — no gradient)
                grad_messages,            // grad_messages
                grad_new_state.clone()};  // grad_states (identity pass-through)
    }
};
|
||
// Register the autograd-aware CUDA kernel. The lambda unwraps the
// single-tensor result of Autograd::apply to match the op schema, which
// returns one Tensor.
TORCH_LIBRARY_IMPL(messages, AutogradCUDA, m) {
    m.impl("pass_messages", [](const Tensor& neighbors,
                               const Tensor& messages,
                               const Tensor& states) {
        return Autograd::apply(neighbors, messages, states)[0];
    });
}