uvm_cache_stats for direct mapped #1951

Closed · wants to merge 1 commit
8 changes: 6 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -114,7 +114,9 @@ void direct_mapped_lru_cache_populate_byte_cuda(
int64_t time_stamp,
at::Tensor lru_state,
at::Tensor lxu_cache_miss_timestamp,
int64_t row_alignment);
int64_t row_alignment,
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

///@ingroup table-batched-embed-cuda
/// LFU cache: fetch the rows corresponding to `linear_cache_indices` from
@@ -174,7 +176,9 @@ at::Tensor emulate_cache_miss(
at::Tensor direct_mapped_lxu_cache_lookup_cuda(
at::Tensor linear_cache_indices,
at::Tensor lxu_cache_state,
int64_t invalid_index);
int64_t invalid_index,
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

//////@ingroup table-batched-embed-cuda
/// Flush the cache: store the weights from the cache to the backing storage.
102 changes: 96 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -11,6 +11,7 @@
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_run_length_encode.cuh>
#include <cub/device/device_select.cuh>
#include <cub/block/block_reduce.cuh>
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
// clang-format on
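
Editor's note, not part of the diff: the kernels below address the stats buffer through an uvm_cache_stats_index enum defined elsewhere in split_embeddings_cache_cuda.cu and unchanged by this PR. The layout sketched here is an assumption, included only so the slot names used in the hunks below are easier to follow.

// Assumed layout of the 1-D int32 uvm_cache_stats tensor (values illustrative):
enum uvm_cache_stats_index {
  num_calls = 0,
  num_requested_indices = 1,
  num_unique_indices = 2,
  num_unique_misses = 3,
  num_conflict_unique_misses = 4,
  num_conflict_misses = 5,
};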

@@ -742,11 +743,24 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel
lxu_cache_state,
const int64_t time_stamp,
pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits> lru_state,
const bool gather_cache_stats,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats,
pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
lxu_cache_miss_timestamp) {
const int32_t N = linear_cache_indices.size(0);
const int32_t C = lxu_cache_state.size(0);

if (gather_cache_stats) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_calls], 1); // N_called.
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_requested_indices],
N); // N_requested_indices.
}
}

CUDA_KERNEL_LOOP(n, N) {
int64_t idx = linear_cache_indices[n];
if (idx == max_indices) {
@@ -893,7 +907,9 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
Tensor lxu_cache_state,
int64_t time_stamp,
Tensor lru_state,
Tensor lxu_cache_miss_timestamp) {
Tensor lxu_cache_miss_timestamp,
bool gather_cache_stats,
Tensor uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
linear_cache_indices,
lxu_cache_state,
@@ -929,6 +945,8 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
time_stamp,
MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
gather_cache_stats,
MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, lxu_cache_miss_timestamp, int64_t, 2, 32));
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1431,6 +1449,9 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
lxu_cache_miss_timestamp,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> cache_sets,
const bool gather_cache_stats,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats,
const int64_t row_alignment) {
const int32_t N = cache_sets.size(0);

@@ -1458,6 +1479,24 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
// continue;
// }

if (gather_cache_stats && threadIdx.x == 0) {
// We are using this slot for a slightly different purpose.
// In 32 way:
// UVM traffic for insert
// = # of inserted rows
// = # of unique misses - # of unique misses that were not inserted
// = uvm_cache_stats_index::num_unique_misses
// - uvm_cache_stats_index::num_conflict_unique_misses
// In Direct Mapped (here):
// UVM traffic for insert
// = # of inserted rows
// = uvm_cache_stats_index::num_conflict_unique_misses
// (just store here directly)
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_unique_misses],
1);
}

// insert the index in the buffer into our only slot
const int32_t insert_slot = 0;
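
Editor's note, not part of the diff: following the accounting described in the comment above, a minimal host-side sketch of how the inserted-row count (UVM insert traffic) could be read back, assuming the enum layout sketched earlier and a CPU copy of the stats tensor; the helper name is hypothetical.

#include <ATen/ATen.h>
#include <cstdint>

// Hypothetical helper: derive "rows inserted from UVM" from uvm_cache_stats.
int64_t inserted_rows(const at::Tensor& uvm_cache_stats, bool direct_mapped) {
  const auto stats = uvm_cache_stats.cpu();
  const int32_t* s = stats.data_ptr<int32_t>();
  constexpr int kNumUniqueMisses = 3;          // assumed enum value
  constexpr int kNumConflictUniqueMisses = 4;  // assumed enum value
  if (direct_mapped) {
    // Direct mapped: the kernel above stores the insert count directly in this slot.
    return s[kNumConflictUniqueMisses];
  }
  // 32-way: inserted rows = unique misses minus those dropped due to conflicts.
  return s[kNumUniqueMisses] - s[kNumConflictUniqueMisses];
}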

@@ -1579,6 +1618,8 @@ void direct_mapped_lru_cache_insert_byte_cuda(
Tensor linear_cache_indices,
Tensor lxu_cache_miss_timestamp,
Tensor cache_sets,
bool gather_cache_stats,
Tensor uvm_cache_stats,
int64_t row_alignment) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
weights,
@@ -1628,6 +1669,8 @@ void direct_mapped_lru_cache_insert_byte_cuda(
MAKE_PTA_WITH_NAME(
func_name, lxu_cache_miss_timestamp, int64_t, 2, 32),
MAKE_PTA_WITH_NAME(func_name, cache_sets, int32_t, 1, 32),
gather_cache_stats,
MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats, int32_t, 1, 32),
row_alignment);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@@ -1739,7 +1782,9 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
int64_t time_stamp,
Tensor lru_state,
Tensor lxu_cache_miss_timestamp,
int64_t row_alignment) {
int64_t row_alignment,
bool gather_cache_stats,
c10::optional<Tensor> uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
weights,
cache_hash_size_cumsum,
@@ -1753,6 +1798,14 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
lru_state,
lxu_cache_miss_timestamp);

if (gather_cache_stats) {
TORCH_CHECK(uvm_cache_stats.has_value());
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
uvm_cache_stats, lxu_cache_weights);
}
auto uvm_cache_stats_ = uvm_cache_stats.value_or(
at::empty({0}, weights.options().dtype(at::kInt)));

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(weights.get_device());

@@ -1795,7 +1848,9 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
lxu_cache_state,
time_stamp,
lru_state,
lxu_cache_miss_timestamp);
lxu_cache_miss_timestamp,
gather_cache_stats,
uvm_cache_stats_);

// insert caching weights
direct_mapped_lru_cache_insert_byte_cuda(
@@ -1812,6 +1867,8 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
linear_cache_indices,
lxu_cache_miss_timestamp,
cache_sets,
gather_cache_stats,
uvm_cache_stats_,
row_alignment);
}

@@ -2632,10 +2689,16 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
lxu_cache_state,
int64_t invalid_index,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
lxu_cache_locations) {
lxu_cache_locations,
const bool gather_cache_stats,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats) {
const int32_t C = lxu_cache_state.size(0);
const int32_t N = linear_cache_indices.size(0);

int32_t n_indices = 0;
int32_t n_hits = 0;

CUDA_KERNEL_LOOP(n, N) {
int32_t cache_location = kCacheLocationMissing;
const auto slot = 0;
@@ -2646,13 +2709,29 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
}

const int32_t cache_set = cache_slot(idx, C);
n_indices++;
const bool found =
(::__ldg((&lxu_cache_state[cache_set][0]) + slot) == idx);
if (found) {
cache_location = cache_set;
n_hits++;
}
lxu_cache_locations[n] = cache_location;
}

if (gather_cache_stats) {
typedef cub::BlockReduce<int32_t, kMaxThreads> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp;

const int32_t conflict_miss = n_indices - n_hits;
const int32_t conflict_miss_sum = BlockReduce(temp).Sum(conflict_miss);

if (threadIdx.x == 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_misses],
conflict_miss_sum);
}
}
}

} // namespace
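
Editor's note, not part of the diff: the lookup kernel above tallies conflict misses with a per-thread count, a cub::BlockReduce sum, and a single atomicAdd from thread 0. Below is a standalone sketch of that pattern, under the assumption that the kernel is launched with exactly kMaxThreads threads per block.

#include <cub/block/block_reduce.cuh>
#include <cstdint>

constexpr int kMaxThreads = 1024; // assumed block size; must match blockDim.x at launch

__global__ void tally_misses(
    const int32_t* __restrict__ hit_flags, // 1 = cache hit, 0 = miss
    int32_t N,
    int32_t* __restrict__ stats) {         // stats[0] accumulates total misses
  int32_t local_misses = 0;
  for (int32_t n = blockIdx.x * blockDim.x + threadIdx.x; n < N;
       n += gridDim.x * blockDim.x) {
    local_misses += (hit_flags[n] == 0);
  }

  // Sum the per-thread tallies across the block in shared memory,
  // then publish one atomic update per block instead of one per thread.
  typedef cub::BlockReduce<int32_t, kMaxThreads> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp;
  const int32_t block_sum = BlockReduce(temp).Sum(local_misses);

  if (threadIdx.x == 0) {
    atomicAdd(&stats[0], block_sum);
  }
}

One atomic per block keeps contention on the stats counter low even when N is large, which is why the kernel reduces within the block before touching global memory.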
@@ -2764,9 +2843,18 @@ DLL_PUBLIC void lxu_cache_locations_update_cuda(
DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
Tensor linear_cache_indices,
Tensor lxu_cache_state,
int64_t invalid_index) {
int64_t invalid_index,
bool gather_cache_stats,
c10::optional<Tensor> uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
linear_cache_indices, lxu_cache_state);
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(uvm_cache_stats, lxu_cache_state);

if (gather_cache_stats) {
TORCH_CHECK(uvm_cache_stats.has_value());
}
auto uvm_cache_stats_ = uvm_cache_stats.value_or(
at::empty({0}, linear_cache_indices.options().dtype(at::kInt)));

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(linear_cache_indices.get_device());
@@ -2796,7 +2884,9 @@ DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
MAKE_PTA_WITH_NAME(func_name, linear_cache_indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
invalid_index,
MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32));
MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
gather_cache_stats,
MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats_, int32_t, 1, 32));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});

4 changes: 2 additions & 2 deletions fbgemm_gpu/src/split_table_batched_embeddings.cpp
@@ -31,7 +31,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, int row_alignment=16, bool gather_cache_stats=False, Tensor(d!)? uvm_cache_stats=None) -> ()");
DISPATCH_TO_CUDA("lru_cache_populate_byte", lru_cache_populate_byte_cuda);
m.def(
"direct_mapped_lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, Tensor(d!) lxu_cache_miss_timestamp, int row_alignment=16) -> ()");
"direct_mapped_lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, Tensor(d!) lxu_cache_miss_timestamp, int row_alignment=16, bool gather_cache_stats=False, Tensor(e!)? uvm_cache_stats=None) -> ()");
DISPATCH_TO_CUDA(
"direct_mapped_lru_cache_populate_byte",
direct_mapped_lru_cache_populate_byte_cuda);
@@ -45,7 +45,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
DISPATCH_TO_CUDA("lxu_cache_lookup", lxu_cache_lookup_cuda);
m.def(
"direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1) -> Tensor");
"direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
DISPATCH_TO_CUDA(
"direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cuda);
m.def(
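
Editor's note, not part of the diff: a hedged call-site sketch of the updated direct-mapped lookup declared in split_embeddings_cache_cuda.cuh above. The six-element stats tensor size and the zero initialization follow the assumptions noted earlier, not the PR itself.

#include <ATen/ATen.h>
#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

at::Tensor lookup_with_stats(
    const at::Tensor& linear_cache_indices,
    const at::Tensor& lxu_cache_state) {
  // Caller-owned stats buffer on the same device as the cache tensors.
  auto uvm_cache_stats =
      at::zeros({6}, linear_cache_indices.options().dtype(at::kInt));
  // With gather_cache_stats=false, c10::nullopt can be passed instead.
  return direct_mapped_lxu_cache_lookup_cuda(
      linear_cache_indices,
      lxu_cache_state,
      /*invalid_index=*/-1,
      /*gather_cache_stats=*/true,
      uvm_cache_stats);
}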