Skip to content

Commit

Permalink
uvm_cache_stats for direct mapped (pytorch#1951)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1951

D40518654 introduced `uvm_cache_stats` to provide cache metrics for FBGEMM 32way cache.
This diff expands its usage to also provide cache metrics for direct mapped cache.

Reviewed By: sryap

Differential Revision: D48023956

fbshipit-source-id: 5ee35c24ef1ae8ad3d70d167ee55784997b0c5a1
  • Loading branch information
SungMinCho authored and facebook-github-bot committed Aug 23, 2023
1 parent ddbfa97 commit f318e6a
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 9 deletions.
8 changes: 6 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ void direct_mapped_lru_cache_populate_byte_cuda(
int64_t time_stamp,
at::Tensor lru_state,
at::Tensor lxu_cache_miss_timestamp,
int64_t row_alignment);
int64_t row_alignment,
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

///@ingroup table-batched-embed-cuda
/// LFU cache: fetch the rows corresponding to `linear_cache_indices` from
Expand Down Expand Up @@ -174,7 +176,9 @@ at::Tensor emulate_cache_miss(
at::Tensor direct_mapped_lxu_cache_lookup_cuda(
at::Tensor linear_cache_indices,
at::Tensor lxu_cache_state,
int64_t invalid_index);
int64_t invalid_index,
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

//////@ingroup table-batched-embed-cuda
/// Flush the cache: store the weights from the cache to the backing storage.
Expand Down
104 changes: 99 additions & 5 deletions fbgemm_gpu/src/split_embeddings_cache_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_run_length_encode.cuh>
#include <cub/device/device_select.cuh>
#include <cub/block/block_reduce.cuh>
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
// clang-format on

Expand Down Expand Up @@ -728,11 +729,24 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel
lxu_cache_state,
const int64_t time_stamp,
at::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits> lru_state,
const bool gather_cache_stats,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats,
at::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
lxu_cache_miss_timestamp) {
const int32_t N = linear_cache_indices.size(0);
const int32_t C = lxu_cache_state.size(0);
if (gather_cache_stats) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_calls], 1); // N_called.
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_requested_indices],
N); // N_requested_indices.
}
}
CUDA_KERNEL_LOOP(n, N) {
int64_t idx = linear_cache_indices[n];
if (idx == max_indices) {
Expand Down Expand Up @@ -879,7 +893,9 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
Tensor lxu_cache_state,
int64_t time_stamp,
Tensor lru_state,
Tensor lxu_cache_miss_timestamp) {
Tensor lxu_cache_miss_timestamp,
bool gather_cache_stats,
Tensor uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
linear_cache_indices,
lxu_cache_state,
Expand Down Expand Up @@ -914,6 +930,9 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
time_stamp,
lru_state.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
gather_cache_stats,
uvm_cache_stats
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
lxu_cache_miss_timestamp
.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand Down Expand Up @@ -1418,6 +1437,9 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
at::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
lxu_cache_miss_timestamp,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> cache_sets,
const bool gather_cache_stats,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats,
const int64_t row_alignment) {
const int32_t N = cache_sets.size(0);
Expand Down Expand Up @@ -1445,6 +1467,24 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
// continue;
// }
if (gather_cache_stats && threadIdx.x == 0) {
// We are using this slot for a slightly different purpose.
// In 32 way:
// UVM traffic for insert
// = # of inserted rows
// = # of unique misses - # of unique misses that were not inserted
// = uvm_cache_stats_index::num_unique_misses
// - uvm_cache_stats_index::num_conflict_unique_misses
// In Direct Mapped (here):
// UVM traffic for insert
// = # of inserted rows
// = uvm_cache_stats_index::num_conflict_unique_misses
// (just store here directly)
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_unique_misses],
1);
}
// insert the index in the buffer into our only slot
const int32_t insert_slot = 0;
Expand Down Expand Up @@ -1568,6 +1608,8 @@ void direct_mapped_lru_cache_insert_byte_cuda(
Tensor linear_cache_indices,
Tensor lxu_cache_miss_timestamp,
Tensor cache_sets,
bool gather_cache_stats,
Tensor uvm_cache_stats,
int64_t row_alignment) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
weights,
Expand Down Expand Up @@ -1618,6 +1660,9 @@ void direct_mapped_lru_cache_insert_byte_cuda(
lxu_cache_miss_timestamp
.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
cache_sets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
gather_cache_stats,
uvm_cache_stats
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
row_alignment);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
Expand Down Expand Up @@ -1729,7 +1774,9 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
int64_t time_stamp,
Tensor lru_state,
Tensor lxu_cache_miss_timestamp,
int64_t row_alignment) {
int64_t row_alignment,
bool gather_cache_stats,
c10::optional<Tensor> uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
weights,
cache_hash_size_cumsum,
Expand All @@ -1743,6 +1790,14 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
lru_state,
lxu_cache_miss_timestamp);
if (gather_cache_stats) {
TORCH_CHECK(uvm_cache_stats.has_value());
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
uvm_cache_stats, lxu_cache_weights);
}
auto uvm_cache_stats_ = uvm_cache_stats.value_or(
at::empty({0}, weights.options().dtype(at::kInt)));
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(weights.get_device());
Expand Down Expand Up @@ -1785,7 +1840,9 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
lxu_cache_state,
time_stamp,
lru_state,
lxu_cache_miss_timestamp);
lxu_cache_miss_timestamp,
gather_cache_stats,
uvm_cache_stats_);
// insert caching weights
direct_mapped_lru_cache_insert_byte_cuda(
Expand All @@ -1802,6 +1859,8 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
linear_cache_indices,
lxu_cache_miss_timestamp,
cache_sets,
gather_cache_stats,
uvm_cache_stats_,
row_alignment);
}
Expand Down Expand Up @@ -2617,10 +2676,16 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
lxu_cache_state,
int64_t invalid_index,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
lxu_cache_locations) {
lxu_cache_locations,
const bool gather_cache_stats,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats) {
const int32_t C = lxu_cache_state.size(0);
const int32_t N = linear_cache_indices.size(0);
int32_t n_indices = 0;
int32_t n_hits = 0;
CUDA_KERNEL_LOOP(n, N) {
int32_t cache_location = kCacheLocationMissing;
const auto slot = 0;
Expand All @@ -2631,13 +2696,29 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
}
const int32_t cache_set = cache_slot(idx, C);
n_indices++;
const bool found =
(::__ldg((&lxu_cache_state[cache_set][0]) + slot) == idx);
if (found) {
cache_location = cache_set;
n_hits++;
}
lxu_cache_locations[n] = cache_location;
}
if (gather_cache_stats) {
typedef cub::BlockReduce<int32_t, kMaxThreads> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp;
int32_t conflict_miss = n_indices - n_hits;
int32_t conflict_miss_sum = BlockReduce(temp).Sum(conflict_miss);
if (threadIdx.x == 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_misses],
conflict_miss_sum);
}
}
}
} // namespace
Expand Down Expand Up @@ -2748,10 +2829,20 @@ DLL_PUBLIC void lxu_cache_locations_update_cuda(
DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
Tensor linear_cache_indices,
Tensor lxu_cache_state,
int64_t invalid_index) {
int64_t invalid_index,
bool gather_cache_stats,
c10::optional<Tensor> uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
linear_cache_indices, lxu_cache_state);
if (gather_cache_stats) {
TORCH_CHECK(uvm_cache_stats.has_value());
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
uvm_cache_stats, linear_cache_indices);
}
auto uvm_cache_stats_ = uvm_cache_stats.value_or(
at::empty({0}, linear_cache_indices.options().dtype(at::kInt)));
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(linear_cache_indices.get_device());
Expand Down Expand Up @@ -2780,6 +2871,9 @@ DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
invalid_index,
lxu_cache_locations
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
gather_cache_stats,
uvm_cache_stats_
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
Expand Down
4 changes: 2 additions & 2 deletions fbgemm_gpu/src/split_table_batched_embeddings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, int row_alignment=16, bool gather_cache_stats=False, Tensor(d!)? uvm_cache_stats=None) -> ()");
DISPATCH_TO_CUDA("lru_cache_populate_byte", lru_cache_populate_byte_cuda);
m.def(
"direct_mapped_lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, Tensor(d!) lxu_cache_miss_timestamp, int row_alignment=16) -> ()");
"direct_mapped_lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, Tensor(d!) lxu_cache_miss_timestamp, int row_alignment=16, bool gather_cache_stats=False, Tensor(e!)? uvm_cache_stats=None) -> ()");
DISPATCH_TO_CUDA(
"direct_mapped_lru_cache_populate_byte",
direct_mapped_lru_cache_populate_byte_cuda);
Expand All @@ -45,7 +45,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
DISPATCH_TO_CUDA("lxu_cache_lookup", lxu_cache_lookup_cuda);
m.def(
"direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1) -> Tensor");
"direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
DISPATCH_TO_CUDA(
"direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cuda);
m.def(
Expand Down

0 comments on commit f318e6a

Please sign in to comment.