batched_dense_vec_jagged_2d_mul: add Meta and Autograd backends
Summary:
Add Meta and Autograd backends for batched_dense_vec_jagged_2d_mul for dynamo and AOT autograd.

This is a more proper way of doing Autograd so that Inductor works for jagged ops. Meta tensor support is also added so we're not at the mercy of arbitrary zero inputs being fed in by default.

Differential Revision: D41387994

fbshipit-source-id: 3dcfe64084b4c65a75c942b3fb531c5d654abe24
brad-mengchi authored and facebook-github-bot committed Nov 18, 2022
1 parent a8c703f commit 2e242ab
Showing 5 changed files with 327 additions and 179 deletions.
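
The diff below renders three of the five changed files; the Meta kernels themselves presumably live in files that are not shown in this excerpt. As a rough, hypothetical sketch of what shape-only Meta forward/backward functions could look like, derived from the output shapes computed by the CUDA kernels below (the function names and the registration block are assumptions, not code from this commit):

#include <tuple>
#include <ATen/ATen.h>
#include <torch/library.h>

namespace fbgemm_gpu {

// Meta tensors carry only shape/dtype/device, so no kernel launch is needed;
// the output shape [B * H, D] mirrors batched_dense_vec_jagged_2d_mul_forward.
at::Tensor batched_dense_vec_jagged_2d_mul_forward_meta(
    const at::Tensor& v,
    const at::Tensor& a_values,
    const at::Tensor& a_offsets) {
  const int64_t B = a_offsets.numel() - 1;
  const int64_t H = (B == 0) ? 1 : v.size(0) / B;
  const int64_t D = a_values.size(-1) / H;
  return at::empty({B * H, D}, v.options());
}

// Gradients have the same shapes as the inputs they correspond to.
std::tuple<at::Tensor, at::Tensor> batched_dense_vec_jagged_2d_mul_backward_meta(
    const at::Tensor& grad_output,
    const at::Tensor& v,
    const at::Tensor& a_values,
    const at::Tensor& a_offsets) {
  return {at::empty_like(v), at::empty_like(a_values)};
}

} // namespace fbgemm_gpu

// Assumed registration under the Meta dispatch key, following the same
// pattern as the CUDA registrations in this commit.
TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
  m.impl(
      "batched_dense_vec_jagged_2d_mul_forward",
      TORCH_FN(fbgemm_gpu::batched_dense_vec_jagged_2d_mul_forward_meta));
  m.impl(
      "batched_dense_vec_jagged_2d_mul_backward",
      TORCH_FN(fbgemm_gpu::batched_dense_vec_jagged_2d_mul_backward_meta));
}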
11 changes: 11 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -473,6 +473,17 @@ std::vector<at::Tensor> stacked_jagged_1d_to_dense_cpu(
int64_t padding_value);
#endif

at::Tensor batched_dense_vec_jagged_2d_mul_forward(
const at::Tensor& v,
const at::Tensor& a_values,
const at::Tensor& a_offsets);

std::tuple<at::Tensor, at::Tensor> batched_dense_vec_jagged_2d_mul_backward(
const at::Tensor& grad_output,
const at::Tensor& v,
const at::Tensor& a_values,
const at::Tensor& a_offsets);

///@ingroup sparse-data-cpu
/// Divide the prediction range (e.g., [0, 1]) into B bins. In each bin, use
/// two parameters to store the number of positive examples and the number of
251 changes: 137 additions & 114 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1873,6 +1873,131 @@ __global__ __launch_bounds__(kMaxThreads) void outer_prod_jagged_2d_output(
}
}
Tensor batched_dense_vec_jagged_2d_mul_forward(
const Tensor& v,
const Tensor& a_values,
const Tensor& a_offsets) {
TENSOR_ON_CUDA_GPU(v);
TENSOR_ON_CUDA_GPU(a_values);
TENSOR_ON_CUDA_GPU(a_offsets);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(v.get_device());
const int B = a_offsets.numel() - 1;
TORCH_CHECK(
B == 0 || v.size(0) % B == 0,
"B, ",
B,
" doesn't divide v.size(0), ",
v.size(0));
const int H = (B == 0) ? 1 : v.size(0) / B;
const int D = a_values.size(-1) / H;
auto output = at::empty({B * H, D}, v.options());
if (B > 0 && D > 0) {
const int block_dim_x =
std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
const int block_dim_y = kMaxThreads / block_dim_x;
AT_DISPATCH_INDEX_TYPES(
a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
a_values.scalar_type(),
"dense_vec_jagged_2d_bmm_kernel_2",
[&] {
dense_vec_jagged_2d_bmm<index_t, scalar_t>
<<<div_round_up(B * H, block_dim_y),
dim3(block_dim_x, block_dim_y),
0,
at::cuda::getCurrentCUDAStream()>>>(
v.packed_accessor32<scalar_t, 2>(),
a_values.packed_accessor32<scalar_t, 2>(),
a_offsets.packed_accessor32<index_t, 1>(),
output.packed_accessor32<scalar_t, 2>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
return output;
}
std::tuple<Tensor, Tensor> batched_dense_vec_jagged_2d_mul_backward(
const Tensor& grad_output,
const Tensor& v,
const Tensor& a_values,
const Tensor& a_offsets) {
TENSOR_ON_CUDA_GPU(grad_output);
TENSOR_ON_CUDA_GPU(a_values);
TENSOR_ON_CUDA_GPU(a_offsets);
TENSOR_ON_CUDA_GPU(v);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());
const int B = a_offsets.numel() - 1;
const int D = grad_output.size(-1);
Tensor a_values_grad = at::zeros_like(a_values);
Tensor v_grad = at::empty_like(v);
if (B > 0 && D > 0) {
TORCH_CHECK(
v.size(0) % B == 0, "B, ", B, " doesn't divide v.size(0), ", v.size(0));
const int H = v.size(0) / B;
const int max_L = v.size(-1);
AT_DISPATCH_INDEX_TYPES(
a_offsets.scalar_type(),
"dense_vec_jagged_2d_bmm_backward_kernel_1",
[&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad_output.scalar_type(),
"dense_vec_jagged_2d_bmm_backward_kernel_2",
[&] {
int block_dim_x = std::min(
div_round_up(max_L, kWarpSize) * kWarpSize, kMaxThreads);
int block_dim_y = kMaxThreads / block_dim_x;
dense_vec_jagged_2d_transposed_bmm<index_t, scalar_t>
<<<div_round_up(B * H, block_dim_y),
dim3(block_dim_x, block_dim_y),
0,
at::cuda::getCurrentCUDAStream()>>>(
grad_output.packed_accessor32<scalar_t, 2>(),
a_values.packed_accessor32<scalar_t, 2>(),
a_offsets.packed_accessor32<index_t, 1>(),
v_grad.packed_accessor32<scalar_t, 2>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
block_dim_x = std::min(
div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
block_dim_y = kMaxThreads / block_dim_x;
outer_prod_jagged_2d_output<index_t, scalar_t>
<<<div_round_up(B * H * max_L, block_dim_y),
dim3(block_dim_x, block_dim_y),
0,
at::cuda::getCurrentCUDAStream()>>>(
v.packed_accessor32<scalar_t, 2>(),
grad_output.packed_accessor32<scalar_t, 2>(),
a_offsets.packed_accessor32<index_t, 1>(),
a_values_grad.packed_accessor32<scalar_t, 2>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
} else {
v_grad.zero_();
}
return {v_grad, a_values_grad};
}
// batched dense vector x jagged 2D tensor multiplication
// dense vector [B H, N]
// jagged tensor [B, N, H D] where N is jagged
@@ -1886,50 +2011,8 @@ class BatchedDenseVecJagged2DMulGPUOp
const Tensor& a_offsets) {
ctx->save_for_backward({v, a_values, a_offsets});
TENSOR_ON_CUDA_GPU(v);
TENSOR_ON_CUDA_GPU(a_values);
TENSOR_ON_CUDA_GPU(a_offsets);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(v.get_device());
const int B = a_offsets.numel() - 1;
TORCH_CHECK(
B == 0 || v.size(0) % B == 0,
"B, ",
B,
" doesn't divide v.size(0), ",
v.size(0));
const int H = (B == 0) ? 1 : v.size(0) / B;
const int D = a_values.size(-1) / H;
auto output = at::empty({B * H, D}, v.options());
if (B > 0 && D > 0) {
const int block_dim_x =
std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
const int block_dim_y = kMaxThreads / block_dim_x;
AT_DISPATCH_INDEX_TYPES(
a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
a_values.scalar_type(),
"dense_vec_jagged_2d_bmm_kernel_2",
[&] {
dense_vec_jagged_2d_bmm<index_t, scalar_t>
<<<div_round_up(B * H, block_dim_y),
dim3(block_dim_x, block_dim_y),
0,
at::cuda::getCurrentCUDAStream()>>>(
v.packed_accessor32<scalar_t, 2>(),
a_values.packed_accessor32<scalar_t, 2>(),
a_offsets.packed_accessor32<index_t, 1>(),
output.packed_accessor32<scalar_t, 2>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
auto output =
batched_dense_vec_jagged_2d_mul_forward(v, a_values, a_offsets);
return {output};
}
@@ -1944,78 +2027,12 @@ class BatchedDenseVecJagged2DMulGPUOp
const Tensor a_offsets = *savedItr++;
TORCH_CHECK(grad_outputs.size() == 1);
TENSOR_ON_CUDA_GPU(grad_outputs[0]);
TENSOR_ON_CUDA_GPU(a_values);
TENSOR_ON_CUDA_GPU(a_offsets);
TENSOR_ON_CUDA_GPU(v);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_outputs[0].get_device());
const int B = a_offsets.numel() - 1;
const int D = grad_outputs[0].size(-1);
Tensor a_values_grad = at::zeros_like(a_values);
Tensor v_grad = at::empty_like(v);
if (B > 0 && D > 0) {
TORCH_CHECK(
v.size(0) % B == 0,
"B, ",
B,
" doesn't divide v.size(0), ",
v.size(0));
const int H = v.size(0) / B;
const int max_L = v.size(-1);
AT_DISPATCH_INDEX_TYPES(
a_offsets.scalar_type(),
"dense_vec_jagged_2d_bmm_baackward_kernel_1",
[&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad_outputs[0].scalar_type(),
"dense_vec_jagged_2d_bmm_baackward_kernel_2",
[&] {
int block_dim_x = std::min(
div_round_up(max_L, kWarpSize) * kWarpSize, kMaxThreads);
int block_dim_y = kMaxThreads / block_dim_x;
dense_vec_jagged_2d_transposed_bmm<index_t, scalar_t>
<<<div_round_up(B * H, block_dim_y),
dim3(block_dim_x, block_dim_y),
0,
at::cuda::getCurrentCUDAStream()>>>(
grad_outputs[0].packed_accessor32<scalar_t, 2>(),
a_values.packed_accessor32<scalar_t, 2>(),
a_offsets.packed_accessor32<index_t, 1>(),
v_grad.packed_accessor32<scalar_t, 2>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
block_dim_x = std::min(
div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
block_dim_y = kMaxThreads / block_dim_x;
outer_prod_jagged_2d_output<index_t, scalar_t>
<<<div_round_up(B * H * max_L, block_dim_y),
dim3(block_dim_x, block_dim_y),
0,
at::cuda::getCurrentCUDAStream()>>>(
v.packed_accessor32<scalar_t, 2>(),
grad_outputs[0].packed_accessor32<scalar_t, 2>(),
a_offsets.packed_accessor32<index_t, 1>(),
a_values_grad.packed_accessor32<scalar_t, 2>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
} else {
v_grad.zero_();
}
auto outputs = batched_dense_vec_jagged_2d_mul_backward(
grad_outputs[0], v, a_values, a_offsets);
return {
v_grad,
a_values_grad,
std::get<0>(outputs),
std::get<1>(outputs),
torch::autograd::Variable(), // a_offsets
};
}
@@ -3038,6 +3055,12 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA(
"batched_dense_vec_jagged_2d_mul",
fbgemm_gpu::batched_dense_vec_jagged_2d_mul);
DISPATCH_TO_CUDA(
"batched_dense_vec_jagged_2d_mul_forward",
fbgemm_gpu::batched_dense_vec_jagged_2d_mul_forward);
DISPATCH_TO_CUDA(
"batched_dense_vec_jagged_2d_mul_backward",
fbgemm_gpu::batched_dense_vec_jagged_2d_mul_backward);
DISPATCH_TO_CUDA(
"jagged_index_select", fbgemm_gpu::jagged_index_select_2d_gpu);
DISPATCH_TO_CUDA("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense_gpu);
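
For reference, the launch configuration in the forward kernel above rounds the per-head output width D up to a whole number of warps and uses the remaining threads for rows. A small worked example of that arithmetic, assuming kWarpSize = 32, kMaxThreads = 1024, and that div_round_up is ordinary ceiling division:

#include <algorithm>
#include <cstdio>

constexpr int div_round_up(int a, int b) {
  return (a + b - 1) / b;  // assumed definition: ceiling division
}

int main() {
  constexpr int kWarpSize = 32;      // assumed value
  constexpr int kMaxThreads = 1024;  // assumed value
  const int D = 100;  // hypothetical embedding dim per head
  // Round D up to a multiple of the warp size, capped at kMaxThreads.
  const int block_dim_x =
      std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
  const int block_dim_y = kMaxThreads / block_dim_x;
  std::printf("block = (%d, %d)\n", block_dim_x, block_dim_y);  // block = (128, 8)
  return 0;
}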
55 changes: 55 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
@@ -167,6 +167,51 @@ class JaggedDenseMulAutogradOp
}
};

class BatchedDenseVecJagged2DMulAutogradOp
: public torch::autograd::Function<BatchedDenseVecJagged2DMulAutogradOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& v,
const Tensor& a_values,
const Tensor& a_offsets) {
ctx->save_for_backward({v, a_values, a_offsets});

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow(
"fbgemm::batched_dense_vec_jagged_2d_mul_forward", "")
.typed<decltype(batched_dense_vec_jagged_2d_mul_forward)>();
auto output = op.call(v, a_values, a_offsets);

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
const auto saved = ctx->get_saved_variables();
auto savedItr = std::begin(saved);
const Tensor v = *savedItr++;
const Tensor a_values = *savedItr++;
const Tensor a_offsets = *savedItr++;
TORCH_CHECK(grad_outputs.size() == 1);

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow(
"fbgemm::batched_dense_vec_jagged_2d_mul_backward", "")
.typed<decltype(batched_dense_vec_jagged_2d_mul_backward)>();
auto outputs = op.call(grad_outputs[0], v, a_values, a_offsets);

return {
std::get<0>(outputs),
std::get<1>(outputs),
torch::autograd::Variable(), // a_offsets
};
}
};

///@ingroup jagged-tensor-ops-autograd
Tensor jagged_to_padded_dense_autograd(
const Tensor& values,
@@ -225,6 +270,13 @@ std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul_autograd(
return {prod_values, x_offsets};
}

Tensor batched_dense_vec_jagged_2d_mul_autograd(
const Tensor& v,
const Tensor& a_values,
const Tensor& a_offsets) {
return BatchedDenseVecJagged2DMulAutogradOp::apply(v, a_values, a_offsets)[0];
}

} // namespace fbgemm_gpu

TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
@@ -242,4 +294,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl(
"jagged_dense_elementwise_mul",
TORCH_FN(fbgemm_gpu::jagged_dense_elementwise_mul_autograd));
m.impl(
"batched_dense_vec_jagged_2d_mul",
TORCH_FN(fbgemm_gpu::batched_dense_vec_jagged_2d_mul_autograd));
}
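
With the Autograd kernel registered for the composite op, a caller can go straight through the dispatcher and still get gradients. A minimal usage sketch, assuming the program links against fbgemm_gpu (so the fbgemm operators are registered) and using hypothetical shapes that follow the comment in jagged_tensor_ops.cu (v is [B * H, max_L], a_values is [total_L, H * D]):

#include <cstdio>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

int main() {
  const int64_t B = 2, H = 1, max_L = 3, D = 4;
  auto opts = torch::dtype(torch::kFloat).device(torch::kCUDA);
  auto v = torch::randn({B * H, max_L}, opts.requires_grad(true));
  // Two jagged rows of lengths 2 and 3, so offsets are [0, 2, 5] and total_L = 5.
  auto a_offsets =
      torch::tensor({0, 2, 5}, torch::dtype(torch::kLong).device(torch::kCUDA));
  auto a_values = torch::randn({5, H * D}, opts.requires_grad(true));

  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::batched_dense_vec_jagged_2d_mul", "")
          .typed<at::Tensor(
              const at::Tensor&, const at::Tensor&, const at::Tensor&)>();
  auto out = op.call(v, a_values, a_offsets);  // shape [B * H, D] = [2, 4]

  // The Autograd backend routes the backward pass through
  // batched_dense_vec_jagged_2d_mul_backward.
  out.sum().backward();
  std::printf(
      "v.grad: %lld x %lld\n",
      static_cast<long long>(v.grad().size(0)),
      static_cast<long long>(v.grad().size(1)));
  return 0;
}

Under dynamo / AOT autograd, the same Autograd formula together with the Meta shape functions is what allows the op to be traced without running the CUDA kernels.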