Skip to content

Commit

Permalink
Refactor OptionalCUDAGuard -> CUDA_DEVICE_GUARD (pytorch#2270)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2270

- Refactor OptionalCUDAGuard -> CUDA_DEVICE_GUARD

Reviewed By: jianyuh

Differential Revision: D52820946

fbshipit-source-id: 7ddd564709fb13b54035af30c2d6f09f7cb1ed26
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 18, 2024
1 parent 67047bf commit 9a3c5b2
Show file tree
Hide file tree
Showing 67 changed files with 125 additions and 248 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
TENSOR_ON_CUDA_GPU(feature_requires_grad);
}

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dev_weights.get_device());
CUDA_DEVICE_GUARD(dev_weights);

const auto T = D_offsets.size(0) - 1;
TORCH_CHECK_GT(T, 0);
// offsets = [B x T + 1]
Expand Down
3 changes: 1 addition & 2 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
}
{%- endif %}

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dev_weights.get_device());
CUDA_DEVICE_GUARD(dev_weights);

{%- if nobag and not is_index_select %}
auto max_D = D;
Expand Down
3 changes: 1 addition & 2 deletions fbgemm_gpu/codegen/embedding_bounds_check.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ void bounds_check_indices_cuda(
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
rows_per_table, indices, offsets, warning, weights, B_offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(rows_per_table.get_device());
CUDA_DEVICE_GUARD(rows_per_table);

const int32_t T = rows_per_table.size(0);
const int32_t total_B = offsets.size(0) - 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ Tensor pruned_hashmap_lookup_cuda(
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
indices, offsets, hash_table, hash_table_offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(indices.get_device());
CUDA_DEVICE_GUARD(indices);

auto dense_indices = at::empty_like(indices);
const int32_t T = hash_table_offsets.size(0) - 1;
const int32_t B = (offsets.size(0) - 1) / T;
Expand Down Expand Up @@ -179,8 +179,8 @@ Tensor pruned_array_lookup_cuda(
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
indices, offsets, index_remappings, index_remappings_offsets);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(indices.get_device());
CUDA_DEVICE_GUARD(indices);
auto dense_indices = at::empty_like(indices);
const int32_t T = index_remappings_offsets.size(0) - 1;
TORCH_CHECK(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dev_weights.get_device());
CUDA_DEVICE_GUARD(dev_weights);

// kernels assume indices are contiguous.
indices = indices.contiguous();
Expand Down
3 changes: 1 addition & 2 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,7 @@ batch_index_select_dim0_codegen_forward_cuda(
}
{%- endif %}

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dev_weights.get_device());
CUDA_DEVICE_GUARD(dev_weights);

{%- if not nobag %}
int32_t T = D_offsets.numel() - 1;
Expand Down
3 changes: 1 addition & 2 deletions fbgemm_gpu/codegen/embedding_optimizer_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ void split_embedding_{{ optimizer }}_update(
return;
}

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dev_weights.get_device());
CUDA_DEVICE_GUARD(dev_weights);

// Flatten dev_weights because it is currently 2D
dev_weights = dev_weights.flatten();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ void embedding_inplace_update_cuda(
lxu_cache_weights,
lxu_cache_locations);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dev_weights.get_device());
CUDA_DEVICE_GUARD(dev_weights);

const int64_t N = update_row_idx.numel();
if (N == 0) {
Expand Down Expand Up @@ -226,9 +225,8 @@ Tensor pruned_array_lookup_from_row_idx_cuda(
update_table_indices,
index_remappings,
index_remappings_offsets);
CUDA_DEVICE_GUARD(update_table_indices);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(update_table_indices.get_device());
auto dense_indices = at::empty_like(update_row_indices);
const int32_t T = index_remappings_offsets.size(0) - 1;
Expand Down
12 changes: 3 additions & 9 deletions fbgemm_gpu/src/histogram_binning_calibration_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ std::tuple<Tensor, Tensor> histogram_binning_calibration_cuda(
TENSOR_ON_CUDA_GPU(bin_num_examples);
TENSOR_ON_CUDA_GPU(bin_num_positives);
TORCH_CHECK_EQ(bin_num_examples.numel(), bin_num_positives.numel());

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(logit.get_device());
CUDA_DEVICE_GUARD(logit);

Tensor calibrated_prediction = at::empty_like(logit);
Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong));
Expand Down Expand Up @@ -188,9 +186,7 @@ std::tuple<Tensor, Tensor> histogram_binning_calibration_by_feature_cuda(
TENSOR_ON_CUDA_GPU(bin_num_examples);
TENSOR_ON_CUDA_GPU(bin_num_positives);
TORCH_CHECK_EQ(bin_num_examples.numel(), bin_num_positives.numel());
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(logit.get_device());
CUDA_DEVICE_GUARD(logit);
// Convert lengths to offsets for better handling on GPUs.
const auto segment_lengths_packed = segment_lengths.contiguous();
Expand Down Expand Up @@ -351,9 +347,7 @@ generic_histogram_binning_calibration_by_feature_cuda(
TORCH_CHECK(
bin_num_examples.numel() ==
(num_segments + 1) * (bin_boundaries.numel() + 1));
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(logit.get_device());
CUDA_DEVICE_GUARD(logit);
// Convert lengths to offsets for better handling on GPUs.
const auto segment_lengths_packed = segment_lengths.contiguous();
Expand Down
3 changes: 1 addition & 2 deletions fbgemm_gpu/src/input_combine_ops/input_combine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cuda(
const uint64_t max_list_size,
const c10::DeviceIndex& device) {
constexpr uint32_t IS_LONG_NUM_BITS = 32;
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(device);
at::cuda::OptionalCUDAGuard device_guard(device);

// combined_indices and combined_lengths are int tensors
const auto int_options = at::TensorOptions().dtype(at::kInt).device(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ std::tuple<Tensor, Tensor> batched_dense_vec_jagged_2d_mul_backward(
const Tensor& a_values,
const Tensor& a_offsets) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad_output, a_values, a_offsets, v);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());
CUDA_DEVICE_GUARD(grad_output);

const int B = a_offsets.numel() - 1;
const int D = grad_output.size(-1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ Tensor batched_dense_vec_jagged_2d_mul_forward(
const Tensor& a_values,
const Tensor& a_offsets) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(v, a_values, a_offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(v.get_device());
CUDA_DEVICE_GUARD(v);

const int B = a_offsets.numel() - 1;
TORCH_CHECK(
Expand Down
3 changes: 1 addition & 2 deletions fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ Tensor dense_to_jagged_forward(
auto values = at::empty_symint({total_L_computed, D}, dense.options());
auto output = at::empty_like(values);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dense.get_device());
CUDA_DEVICE_GUARD(dense);

#define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE) \
AT_DISPATCH_CASE(TYPE, [&] { \
Expand Down
4 changes: 1 addition & 3 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ Tensor jagged_dense_bmm_forward_cuda(
const Tensor& y,
const int64_t max_L) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(x_values, x_offsets, y);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());
CUDA_DEVICE_GUARD(x_values);

const int B = x_offsets.numel() - 1;
const int M = x_values.size(-1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ Tensor jagged_dense_dense_elementwise_add_jagged_output_forward(
TORCH_CHECK_EQ(dense_0.sizes(), dense_1.sizes());
auto output = at::empty_like(x_values);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dense_0.get_device());
CUDA_DEVICE_GUARD(dense_0);
if (x_values.scalar_type() == at::ScalarType::BFloat16 &&
dense_0.scalar_type() == at::ScalarType::BFloat16 &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ std::tuple<Tensor, Tensor> jagged_dense_elementwise_mul_backward(
const std::vector<Tensor>& x_offsets,
const Tensor& y,
const Tensor& x_values) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());
CUDA_DEVICE_GUARD(grad_output);

Tensor x_values_grad = at::empty_like(grad_output);
Tensor y_grad = at::empty_like(y);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ Tensor jagged_dense_elementwise_mul_forward(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());
CUDA_DEVICE_GUARD(x_values);

Tensor output = at::empty_like(x_values);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ Tensor jagged_index_add_2d_forward_cuda(
const int64_t num_output_rows) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
values, indices, input_offsets, output_offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
CUDA_DEVICE_GUARD(values);

auto num_cols = values.size(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ Tensor jagged_index_select_2d_forward_cuda(
const int64_t num_dense_output_rows) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
values, indices, input_offsets, output_offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
CUDA_DEVICE_GUARD(values);

auto num_cols = values.size(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ Tensor jagged_jagged_bmm_forward_cuda(
const Tensor& offsets,
const int64_t max_L) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(x_values, y_values, offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());
CUDA_DEVICE_GUARD(x_values);

const int B = offsets.numel() - 1;
const int M = x_values.size(-1);
Expand Down
4 changes: 1 addition & 3 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ Tensor jagged_softmax_backward_cuda(
const Tensor& offsets,
const int64_t max_L) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad_output, output, offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());
CUDA_DEVICE_GUARD(grad_output);

const auto B = offsets.numel() - 1;
const auto D = grad_output.size(1);
Expand Down
4 changes: 1 addition & 3 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ Tensor jagged_softmax_forward_cuda(
const Tensor& offsets,
const int64_t max_L) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(values, offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
CUDA_DEVICE_GUARD(values);

const auto B = offsets.numel() - 1;
const auto D = values.size(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ at::Tensor jagged_to_padded_dense_backward(
const std::vector<Tensor>& offsets,
at::SymInt total_L) {
auto grad_padded_values = grad_output;
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_padded_values.get_device());
CUDA_DEVICE_GUARD(grad_padded_values);

// Canonicalize padded_values by unsqueeze the last dim if the inner dense
// dimension is 1 and folded.
Expand Down
19 changes: 6 additions & 13 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ at::Tensor jagged_to_padded_dense_forward(
max_lengths.size(),
" != num_jagged_dim, ",
num_jagged_dim);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
CUDA_DEVICE_GUARD(values);

const Tensor values_canonicalized = values.view(
{values.size(0),
Expand Down Expand Up @@ -83,8 +82,7 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
int64_t padding_value) {
TORCH_CHECK(values.dim() == 1);
TORCH_CHECK(lengths.dim() == 2);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
CUDA_DEVICE_GUARD(values);

const auto lengths_contig = lengths.contiguous();
int32_t B = lengths.size(1);
Expand Down Expand Up @@ -138,8 +136,7 @@ stacked_jagged_2d_to_dense_forward_cuda(
int64_t padding_value) {
TORCH_CHECK(values.dim() == 2);
TORCH_CHECK(lengths.dim() == 2);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
CUDA_DEVICE_GUARD(values);

const auto lengths_contig = lengths.contiguous();
int32_t D = values.size(1);
Expand Down Expand Up @@ -194,8 +191,7 @@ Tensor stacked_jagged_2d_to_dense_backward_cuda(
const std::vector<Tensor>& grad_padded_values_per_key,
const std::vector<Tensor>& offsets_tensor_per_key,
const std::vector<int64_t>& offset_per_key) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_padded_values_per_key[0].get_device());
CUDA_DEVICE_GUARD(grad_padded_values_per_key[0]);

auto grad_values =
at::zeros({total_L, D}, grad_padded_values_per_key[0].options());
Expand Down Expand Up @@ -321,8 +317,7 @@ class JaggedDenseAddJaggedOutputGPUOp

auto output = at::empty_like(x_values);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dense.get_device());
CUDA_DEVICE_GUARD(dense);

AT_DISPATCH_SWITCH(
x_values.scalar_type(),
Expand Down Expand Up @@ -364,9 +359,7 @@ class JaggedDenseAddJaggedOutputGPUOp
auto offsets = ctx->get_saved_variables();
auto dense_shape = ctx->saved_data["dense_shape"].toIntVector();
TORCH_CHECK(grad_outputs.size() == 1);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_outputs[0].get_device());
CUDA_DEVICE_GUARD(grad_outputs[0]);

Tensor dense_values_grad = jagged_to_padded_dense_forward(
grad_outputs[0],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
"weights size and values size must be the same");
}

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
CUDA_DEVICE_GUARD(values);

const int num_batches = lengths.numel() / batch_size;
const int num_output_lengths = num_batches * indices.numel();
Expand Down Expand Up @@ -380,8 +379,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
int64_t output_batch_size = ctx->saved_data["batch_size"].toInt();
int64_t num_batches = ctx->saved_data["num_batches"].toInt();
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad.get_device());
CUDA_DEVICE_GUARD(grad);
Tensor grad_input = at::zeros({num_outputs}, grad.options());
auto grid_size = cuda_calc_xblock_count(grad.numel(), kMaxThreads);
Expand Down
9 changes: 3 additions & 6 deletions fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ Tensor recat_embedding_grad_output_cuda(
const std::vector<int64_t>& num_features_per_rank) {
TENSOR_ON_CUDA_GPU(grad_output);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());
CUDA_DEVICE_GUARD(grad_output);

TORCH_CHECK(grad_output.is_contiguous());
const auto B_local = grad_output.size(0);
Expand Down Expand Up @@ -82,8 +81,7 @@ Tensor recat_embedding_grad_output_mixed_D_cuda(
TENSOR_ON_CUDA_GPU(grad_output);
TORCH_CHECK(grad_output.is_contiguous());

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());
CUDA_DEVICE_GUARD(grad_output);

const auto B_local = grad_output.size(0);
const auto global_dim_sum = at::sum_integers(dim_sum_per_rank);
Expand Down Expand Up @@ -129,8 +127,7 @@ Tensor recat_embedding_grad_output_mixed_D_batch_cuda(
grad_output, dim_sum_per_rank, cumsum_dim_sum_per_rank);
TORCH_CHECK(grad_output.is_contiguous());

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());
CUDA_DEVICE_GUARD(grad_output);

const auto B_local = grad_output.size(0);
Tensor sharded_grad_output =
Expand Down
Loading

0 comments on commit 9a3c5b2

Please sign in to comment.