batched_dense_vec_jagged_2d_mul add Meta and Autograd backend (#1468)
Summary:
Pull Request resolved: #1468

Add Meta and Autograd backends for batched_dense_vec_jagged_2d_mul to support dynamo and AOT Autograd.

This is a more proper way of doing Autograd that makes Inductor work for jagged ops. Meta tensor support is also added so we are no longer at the mercy of the arbitrary zero inputs that are fed in by default.

Reviewed By: xiaosun86, yf225, jianyuh

Differential Revision: D41387994

fbshipit-source-id: 9c6d925714a3477e9bdb7b0d041bda162a4ad61d
brad-mengchi authored and facebook-github-bot committed Nov 29, 2022
1 parent e4f6ed8 commit e749c95
Showing 6 changed files with 373 additions and 252 deletions.
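
Note on the Meta backend mentioned in the summary: the Meta kernel itself lives in one of the six changed files that is not reproduced below. As a hedged sketch (the function name and form here are illustrative, not taken from this commit), a Meta implementation only propagates shapes and dtypes, mirroring the shape arithmetic of the CUDA forward in jagged_tensor_ops.cu, so dynamo and AOT Autograd can trace the op without materializing device data:

// Illustrative sketch only -- not a file from this commit.
// A Meta kernel computes output metadata from input shapes and never
// touches element data.
at::Tensor batched_dense_vec_jagged_2d_mul_meta(
    const at::Tensor& v,           // dense vectors, [B * H, max_N]
    const at::Tensor& a_values,    // jagged values, [sum_N, H * D]
    const at::Tensor& a_offsets) { // row offsets, [B + 1]
  const auto B = a_offsets.numel() - 1;
  TORCH_CHECK(B == 0 || v.size(0) % B == 0);
  const auto H = (B == 0) ? 1 : v.size(0) / B; // heads folded into dim 0 of v
  const auto D = a_values.size(-1) / H;        // per-head output width
  return at::empty({B * H, D}, v.options());   // shape-only allocation
}
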
5 changes: 5 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -402,6 +402,11 @@ jagged_dense_dense_elementwise_add_jagged_output_autograd(
const at::Tensor& y_0,
const at::Tensor& y_1);

at::Tensor batched_dense_vec_jagged_2d_mul_autograd(
const at::Tensor& v,
const at::Tensor& a_values,
const at::Tensor& a_offsets);

std::tuple<at::Tensor, std::vector<at::Tensor>>
jagged_dense_elementwise_mul_autograd(
const at::Tensor& x_values,
259 changes: 117 additions & 142 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1702,160 +1702,129 @@ __global__ __launch_bounds__(kMaxThreads) void outer_prod_jagged_2d_output(
}
}
// batched dense vector x jagged 2D tensor multiplication
// dense vector [B H, N]
// jagged tensor [B, N, H D] where N is jagged
-class BatchedDenseVecJagged2DMulGPUOp
-    : public torch::autograd::Function<BatchedDenseVecJagged2DMulGPUOp> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const Tensor& v,
-      const Tensor& a_values,
-      const Tensor& a_offsets) {
-    ctx->save_for_backward({v, a_values, a_offsets});
-    TENSOR_ON_CUDA_GPU(v);
-    TENSOR_ON_CUDA_GPU(a_values);
-    TENSOR_ON_CUDA_GPU(a_offsets);
-
-    at::cuda::OptionalCUDAGuard device_guard;
-    device_guard.set_index(v.get_device());
-
-    const int B = a_offsets.numel() - 1;
-    TORCH_CHECK(
-        B == 0 || v.size(0) % B == 0,
-        "B, ",
-        B,
-        " doesn't divide v.size(0), ",
-        v.size(0));
-    const int H = (B == 0) ? 1 : v.size(0) / B;
-    const int D = a_values.size(-1) / H;
-    auto output = at::empty({B * H, D}, v.options());
-
-    if (B > 0 && D > 0) {
-      const int block_dim_x =
-          std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
-      const int block_dim_y = kMaxThreads / block_dim_x;
-
-      AT_DISPATCH_INDEX_TYPES(
-          a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
-            AT_DISPATCH_FLOATING_TYPES_AND2(
-                at::ScalarType::Half,
-                at::ScalarType::BFloat16,
-                a_values.scalar_type(),
-                "dense_vec_jagged_2d_bmm_kernel_2",
-                [&] {
-                  dense_vec_jagged_2d_bmm<index_t, scalar_t>
-                      <<<div_round_up(B * H, block_dim_y),
-                         dim3(block_dim_x, block_dim_y),
-                         0,
-                         at::cuda::getCurrentCUDAStream()>>>(
-                          v.packed_accessor32<scalar_t, 2>(),
-                          a_values.packed_accessor32<scalar_t, 2>(),
-                          a_offsets.packed_accessor32<index_t, 1>(),
-                          output.packed_accessor32<scalar_t, 2>());
-                  C10_CUDA_KERNEL_LAUNCH_CHECK();
-                });
-          });
-    }
-
-    return {output};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::variable_list grad_outputs) {
-    const auto saved = ctx->get_saved_variables();
-    auto savedItr = std::begin(saved);
-    const Tensor v = *savedItr++;
-    const Tensor a_values = *savedItr++;
-    const Tensor a_offsets = *savedItr++;
-    TORCH_CHECK(grad_outputs.size() == 1);
-
-    TENSOR_ON_CUDA_GPU(grad_outputs[0]);
-    TENSOR_ON_CUDA_GPU(a_values);
-    TENSOR_ON_CUDA_GPU(a_offsets);
-    TENSOR_ON_CUDA_GPU(v);
-
-    at::cuda::OptionalCUDAGuard device_guard;
-    device_guard.set_index(grad_outputs[0].get_device());
-
-    const int B = a_offsets.numel() - 1;
-    const int D = grad_outputs[0].size(-1);
-
-    Tensor a_values_grad = at::zeros_like(a_values);
-    Tensor v_grad = at::empty_like(v);
-
-    if (B > 0 && D > 0) {
-      TORCH_CHECK(
-          v.size(0) % B == 0,
-          "B, ",
-          B,
-          " doesn't divide v.size(0), ",
-          v.size(0));
-      const int H = v.size(0) / B;
-      const int max_L = v.size(-1);
-
-      AT_DISPATCH_INDEX_TYPES(
-          a_offsets.scalar_type(),
-          "dense_vec_jagged_2d_bmm_baackward_kernel_1",
-          [&] {
-            AT_DISPATCH_FLOATING_TYPES_AND2(
-                at::ScalarType::Half,
-                at::ScalarType::BFloat16,
-                grad_outputs[0].scalar_type(),
-                "dense_vec_jagged_2d_bmm_baackward_kernel_2",
-                [&] {
-                  int block_dim_x = std::min(
-                      div_round_up(max_L, kWarpSize) * kWarpSize, kMaxThreads);
-                  int block_dim_y = kMaxThreads / block_dim_x;
-                  dense_vec_jagged_2d_transposed_bmm<index_t, scalar_t>
-                      <<<div_round_up(B * H, block_dim_y),
-                         dim3(block_dim_x, block_dim_y),
-                         0,
-                         at::cuda::getCurrentCUDAStream()>>>(
-                          grad_outputs[0].packed_accessor32<scalar_t, 2>(),
-                          a_values.packed_accessor32<scalar_t, 2>(),
-                          a_offsets.packed_accessor32<index_t, 1>(),
-                          v_grad.packed_accessor32<scalar_t, 2>());
-                  C10_CUDA_KERNEL_LAUNCH_CHECK();
-
-                  block_dim_x = std::min(
-                      div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
-                  block_dim_y = kMaxThreads / block_dim_x;
-                  outer_prod_jagged_2d_output<index_t, scalar_t>
-                      <<<div_round_up(B * H * max_L, block_dim_y),
-                         dim3(block_dim_x, block_dim_y),
-                         0,
-                         at::cuda::getCurrentCUDAStream()>>>(
-                          v.packed_accessor32<scalar_t, 2>(),
-                          grad_outputs[0].packed_accessor32<scalar_t, 2>(),
-                          a_offsets.packed_accessor32<index_t, 1>(),
-                          a_values_grad.packed_accessor32<scalar_t, 2>());
-                  C10_CUDA_KERNEL_LAUNCH_CHECK();
-                });
-          });
-    } else {
-      v_grad.zero_();
-    }
-
-    return {
-        v_grad,
-        a_values_grad,
-        torch::autograd::Variable(), // a_offsets
-    };
-  }
-};
-
-///@ingroup jagged-tensor-ops-cuda
-Tensor batched_dense_vec_jagged_2d_mul(
-    const Tensor& v,
-    const Tensor& a_values,
-    const Tensor& a_offsets) {
-  return BatchedDenseVecJagged2DMulGPUOp::apply(v, a_values, a_offsets)[0];
-}
+Tensor batched_dense_vec_jagged_2d_mul_forward(
+    const Tensor& v,
+    const Tensor& a_values,
+    const Tensor& a_offsets) {
+  TENSOR_ON_CUDA_GPU(v);
+  TENSOR_ON_CUDA_GPU(a_values);
+  TENSOR_ON_CUDA_GPU(a_offsets);
+
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(v.get_device());
+
+  const int B = a_offsets.numel() - 1;
+  TORCH_CHECK(
+      B == 0 || v.size(0) % B == 0,
+      "B, ",
+      B,
+      " doesn't divide v.size(0), ",
+      v.size(0));
+  const int H = (B == 0) ? 1 : v.size(0) / B;
+  const int D = a_values.size(-1) / H;
+  auto output = at::empty({B * H, D}, v.options());
+
+  if (B > 0 && D > 0) {
+    const int block_dim_x =
+        std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
+    const int block_dim_y = kMaxThreads / block_dim_x;
+
+    AT_DISPATCH_INDEX_TYPES(
+        a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
+          AT_DISPATCH_FLOATING_TYPES_AND2(
+              at::ScalarType::Half,
+              at::ScalarType::BFloat16,
+              a_values.scalar_type(),
+              "dense_vec_jagged_2d_bmm_kernel_2",
+              [&] {
+                dense_vec_jagged_2d_bmm<index_t, scalar_t>
+                    <<<div_round_up(B * H, block_dim_y),
+                       dim3(block_dim_x, block_dim_y),
+                       0,
+                       at::cuda::getCurrentCUDAStream()>>>(
+                        v.packed_accessor32<scalar_t, 2>(),
+                        a_values.packed_accessor32<scalar_t, 2>(),
+                        a_offsets.packed_accessor32<index_t, 1>(),
+                        output.packed_accessor32<scalar_t, 2>());
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+              });
+        });
+  }
+
+  return output;
+}
+
+std::tuple<Tensor, Tensor> batched_dense_vec_jagged_2d_mul_backward(
+    const Tensor& grad_output,
+    const Tensor& v,
+    const Tensor& a_values,
+    const Tensor& a_offsets) {
+  TENSOR_ON_CUDA_GPU(grad_output);
+  TENSOR_ON_CUDA_GPU(a_values);
+  TENSOR_ON_CUDA_GPU(a_offsets);
+  TENSOR_ON_CUDA_GPU(v);
+
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(grad_output.get_device());
+
+  const int B = a_offsets.numel() - 1;
+  const int D = grad_output.size(-1);
+
+  Tensor a_values_grad = at::zeros_like(a_values);
+  Tensor v_grad = at::empty_like(v);
+
+  if (B > 0 && D > 0) {
+    TORCH_CHECK(
+        v.size(0) % B == 0, "B, ", B, " doesn't divide v.size(0), ", v.size(0));
+    const int H = v.size(0) / B;
+    const int max_L = v.size(-1);
+
+    AT_DISPATCH_INDEX_TYPES(
+        a_offsets.scalar_type(),
+        "dense_vec_jagged_2d_bmm_backward_kernel_1",
+        [&] {
+          AT_DISPATCH_FLOATING_TYPES_AND2(
+              at::ScalarType::Half,
+              at::ScalarType::BFloat16,
+              grad_output.scalar_type(),
+              "dense_vec_jagged_2d_bmm_backward_kernel_2",
+              [&] {
+                int block_dim_x = std::min(
+                    div_round_up(max_L, kWarpSize) * kWarpSize, kMaxThreads);
+                int block_dim_y = kMaxThreads / block_dim_x;
+                dense_vec_jagged_2d_transposed_bmm<index_t, scalar_t>
+                    <<<div_round_up(B * H, block_dim_y),
+                       dim3(block_dim_x, block_dim_y),
+                       0,
+                       at::cuda::getCurrentCUDAStream()>>>(
+                        grad_output.packed_accessor32<scalar_t, 2>(),
+                        a_values.packed_accessor32<scalar_t, 2>(),
+                        a_offsets.packed_accessor32<index_t, 1>(),
+                        v_grad.packed_accessor32<scalar_t, 2>());
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+                block_dim_x = std::min(
+                    div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
+                block_dim_y = kMaxThreads / block_dim_x;
+                outer_prod_jagged_2d_output<index_t, scalar_t>
+                    <<<div_round_up(B * H * max_L, block_dim_y),
+                       dim3(block_dim_x, block_dim_y),
+                       0,
+                       at::cuda::getCurrentCUDAStream()>>>(
+                        v.packed_accessor32<scalar_t, 2>(),
+                        grad_output.packed_accessor32<scalar_t, 2>(),
+                        a_offsets.packed_accessor32<index_t, 1>(),
+                        a_values_grad.packed_accessor32<scalar_t, 2>());
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+              });
+        });
+  } else {
+    v_grad.zero_();
+  }
+
+  return {v_grad, a_values_grad};
+}
} // namespace
@@ -2851,9 +2820,15 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA(
"jagged_dense_elementwise_mul",
fbgemm_gpu::jagged_dense_elementwise_mul_autograd);
+  DISPATCH_TO_CUDA(
+      "batched_dense_vec_jagged_2d_mul_forward",
+      fbgemm_gpu::batched_dense_vec_jagged_2d_mul_forward);
+  DISPATCH_TO_CUDA(
+      "batched_dense_vec_jagged_2d_mul_backward",
+      fbgemm_gpu::batched_dense_vec_jagged_2d_mul_backward);
  DISPATCH_TO_CUDA(
      "batched_dense_vec_jagged_2d_mul",
-      fbgemm_gpu::batched_dense_vec_jagged_2d_mul);
+      fbgemm_gpu::batched_dense_vec_jagged_2d_mul_autograd);
DISPATCH_TO_CUDA(
"jagged_index_select", fbgemm_gpu::jagged_index_select_2d_gpu);
DISPATCH_TO_CUDA(
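
The two DISPATCH_TO_CUDA registrations for the new forward and backward ops above assume that matching operator schemas are declared elsewhere (in one of the changed files not shown in this excerpt). A hedged sketch of what those declarations would look like, with the schema strings inferred from the typed<> signatures used later in this diff rather than copied from the commit:

// Assumed schema declarations (sketch); the actual m.def() calls live in a
// changed file that is not part of this excerpt.
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def(
      "batched_dense_vec_jagged_2d_mul_forward(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor");
  m.def(
      "batched_dense_vec_jagged_2d_mul_backward(Tensor grad_output, Tensor v, Tensor a_values, Tensor a_offsets) -> (Tensor, Tensor)");
}
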
66 changes: 66 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
@@ -179,6 +179,61 @@ class JaggedDenseMulOp : public torch::autograd::Function<JaggedDenseMulOp> {
}
};

// batched dense vector x jagged 2D tensor multiplication
// dense vector [B H, N]
// jagged tensor [B, N, H D] where N is jagged
class BatchedDenseVecJagged2DMulOp
: public torch::autograd::Function<BatchedDenseVecJagged2DMulOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& v,
const Tensor& a_values,
const Tensor& a_offsets) {
ctx->save_for_backward({v, a_values, a_offsets});

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow(
"fbgemm::batched_dense_vec_jagged_2d_mul_forward", "")
.typed<Tensor(
const Tensor& v,
const Tensor& a_values,
const Tensor& a_offsets)>();
Tensor output = op.call(v, a_values, a_offsets);

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
const auto saved = ctx->get_saved_variables();
auto savedItr = std::begin(saved);
const Tensor v = *savedItr++;
const Tensor a_values = *savedItr++;
const Tensor a_offsets = *savedItr++;
TORCH_CHECK(grad_outputs.size() == 1);

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow(
"fbgemm::batched_dense_vec_jagged_2d_mul_backward", "")
.typed<std::tuple<Tensor, Tensor>(
const Tensor& grad_output,
const Tensor& v,
const Tensor& a_values,
const Tensor& a_offsets)>();
auto outputs = op.call(grad_outputs[0], v, a_values, a_offsets);

return {
std::get<0>(outputs),
std::get<1>(outputs),
torch::autograd::Variable(), // a_offsets
};
}
};

} // namespace

///@ingroup jagged-tensor-ops-cpu
@@ -238,6 +293,14 @@ std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul_autograd(
return {prod_values, x_offsets};
}

///@ingroup jagged-tensor-ops-cpu
Tensor batched_dense_vec_jagged_2d_mul_autograd(
const Tensor& v,
const Tensor& a_values,
const Tensor& a_offsets) {
return BatchedDenseVecJagged2DMulOp::apply(v, a_values, a_offsets)[0];
}

} // namespace fbgemm_gpu

TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
@@ -255,4 +318,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl(
"jagged_dense_elementwise_mul",
TORCH_FN(fbgemm_gpu::jagged_dense_elementwise_mul_autograd));
m.impl(
"batched_dense_vec_jagged_2d_mul",
TORCH_FN(fbgemm_gpu::batched_dense_vec_jagged_2d_mul_autograd));
}
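
With this Autograd-key registration in place, a call to fbgemm::batched_dense_vec_jagged_2d_mul routes through BatchedDenseVecJagged2DMulOp, whose backward fetches fbgemm::batched_dense_vec_jagged_2d_mul_backward from the dispatcher. A minimal usage sketch, assuming the FBGEMM ops are loaded and a CUDA device is available (shapes are illustrative: B = 2, H = 3, max_N = 5, D = 4):

// Usage sketch under the assumptions stated above; not code from this commit.
const auto mul_op =
    c10::Dispatcher::singleton()
        .findSchemaOrThrow("fbgemm::batched_dense_vec_jagged_2d_mul", "")
        .typed<at::Tensor(
            const at::Tensor&, const at::Tensor&, const at::Tensor&)>();

auto opts = torch::dtype(torch::kFloat).device(torch::kCUDA);
auto v = torch::randn({2 * 3, 5}, opts).requires_grad_(true);        // [B * H, max_N]
auto a_values = torch::randn({8, 3 * 4}, opts).requires_grad_(true); // [sum_N, H * D]
auto a_offsets = torch::tensor({0, 5, 8}, opts.dtype(torch::kInt));  // [B + 1]

auto out = mul_op.call(v, a_values, a_offsets); // [B * H, D] = [6, 4]
out.sum().backward(); // Autograd key -> BatchedDenseVecJagged2DMulOp::backward
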