uvm_cache_stats for direct mapped (pytorch#1951)
Summary:
Pull Request resolved: pytorch#1951

D40518654 introduced `uvm_cache_stats` to provide cache metrics for the FBGEMM 32-way cache.
This diff extends it to also provide cache metrics for the direct-mapped cache.

Differential Revision: D48023956

fbshipit-source-id: 09846b188c3d65814f8b291e8a7b659835b14c73
SungMinCho authored and facebook-github-bot committed Aug 18, 2023
1 parent 016a27b commit 0e3a4e3
Showing 6 changed files with 263 additions and 20 deletions.
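For context, the benchmark changes below read the gathered statistics as a flat tensor with six counters. The helper below is a hypothetical illustration (not part of FBGEMM) of how a caller might turn that tensor into named metrics and derived rates, using the counter order that appears in the benchmark diff:

# Hypothetical helper, not FBGEMM API: interpret a gathered uvm_cache_stats
# tensor. The counter order matches the benchmark code in this commit.
from typing import Dict

import torch


def summarize_uvm_cache_stats(uvm_cache_stats: torch.Tensor) -> Dict[str, float]:
    s = uvm_cache_stats.detach().cpu().tolist()
    stats = {
        "num_calls": s[0],
        "num_requested_indices": s[1],
        "num_unique_indices": s[2],
        "num_unique_misses": s[3],
        "num_conflict_unique_misses": s[4],
        "num_conflict_misses": s[5],
    }
    # Derived ratios, guarded against empty runs (illustrative definitions).
    if stats["num_requested_indices"]:
        stats["conflict_miss_rate"] = stats["num_conflict_misses"] / stats["num_requested_indices"]
    if stats["num_unique_indices"]:
        stats["unique_miss_rate"] = stats["num_unique_misses"] / stats["num_unique_indices"]
    return stats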
33 changes: 22 additions & 11 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -1727,6 +1727,7 @@ def nbit_uvm(
@click.option("--fp8-exponent-bits", type=int, default=None)
@click.option("--fp8-exponent-bias", type=int, default=None)
@click.option("--record-cache", is_flag=True, default=False)
@click.option("--uvm-host-mapped", is_flag=True, default=False)
@click.option(
"--dump-requests", type=int, default=0, help="number of reqs to dump (0=no dump)"
)
@@ -1753,6 +1754,7 @@ def nbit_uvm_compare_direct_mapped(
fp8_exponent_bits: Optional[int],
fp8_exponent_bias: Optional[int],
record_cache: bool,
uvm_host_mapped: bool,
dump_requests: int,
) -> None:
logging.info(json.dumps({k: str(v) for k, v in locals().items()}, indent=2))
@@ -1837,18 +1839,21 @@ def bench_uvm_cls(
enforce_hbm=enforce_hbm,
fp8_exponent_bits=fp8_exponent_bits,
fp8_exponent_bias=fp8_exponent_bias,
record_cache_metrics=RecordCacheMetrics(record_cache, record_cache),
gather_uvm_cache_stats=record_cache,
uvm_host_mapped=uvm_host_mapped,
).cuda()
emb.fill_random_weights()

# label nvtx only when cache counter is off
nvtx_range = "" if record_cache else f"UVM-{name.upper()}"
callback_after_warmup = emb.reset_cache_miss_counter if record_cache else None
requests = requests_uvm[:1] if record_cache else requests_uvm
nvtx_range = (
f"UVM-RECORD-CACHE-{name.upper()}"
if record_cache
else f"UVM-{name.upper()}"
)
callback_after_warmup = emb.reset_uvm_cache_stats if record_cache else None

torch.cuda.cudart().cudaProfilerStart()
time_per_iter = benchmark_requests(
requests,
requests_uvm,
lambda indices, offsets, per_sample_weights: emb.forward(
indices.int(),
offsets.int(),
@@ -1881,12 +1886,14 @@ def bench_uvm_cls(
)

if record_cache:
cmc = emb.cache_miss_counter.detach().cpu().numpy().tolist()
cmc = emb.uvm_cache_stats.detach().cpu().numpy().tolist()
cache_stats = {
"miss_forward_count": cmc[0],
"unique_miss": cmc[1],
"unique_req": cmc[2],
"nondedup_req": cmc[3],
"num_calls": cmc[0],
"num_requested_indices": cmc[1],
"num_unique_indices": cmc[2],
"num_unique_misses": cmc[3],
"num_conflict_unique_misses": cmc[4],
"num_conflict_misses": cmc[5],
}
stats[name]["cache_stats"] = cache_stats
logging.info(f"[{name:>8s}] cache stats {cache_stats}")
@@ -1932,6 +1939,7 @@ def bench_uvm_cls(
@click.option("--batch-size", default=512)
@click.option("--cache-algorithm", default="lru")
@click.option("--cache-load-factor", default=0.2)
@click.option("--cache-assoc", default=32)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
@click.option("--iters", default=100)
@@ -1954,6 +1962,7 @@ def nbit_cache( # noqa C901
batch_size: int,
cache_algorithm: str,
cache_load_factor: float,
cache_assoc: int,
embedding_dim: int,
weights_precision: SparseType,
iters: int,
@@ -2003,6 +2012,7 @@ def nbit_cache( # noqa C901
enforce_hbm=enforce_hbm,
fp8_exponent_bits=fp8_exponent_bits,
fp8_exponent_bias=fp8_exponent_bias,
cache_assoc=cache_assoc,
).cuda()
emb_nc.fill_random_weights()

@@ -2027,6 +2037,7 @@ def nbit_cache( # noqa C901
enforce_hbm=enforce_hbm,
fp8_exponent_bits=fp8_exponent_bits,
fp8_exponent_bias=fp8_exponent_bias,
cache_assoc=cache_assoc,
).cuda()
emb.fill_random_weights()

@@ -608,6 +608,9 @@ def prefetch_1way(self, linear_cache_indices: Tensor) -> None:
self.timestep_counter.get(),
self.lxu_state,
self.lxu_cache_miss_timestamp,
16, # row_alignment; using default value.
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)
else:
raise ValueError("Direct Mapped for LRU only")
@@ -620,8 +623,18 @@ def prefetch_1way(self, linear_cache_indices: Tensor) -> None:
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)
)
if self.gather_uvm_cache_stats:
# Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
# We may want to do this accumulation atomically, but since it is only for
# monitoring, a slightly inaccurate result may be acceptable.
self.uvm_cache_stats = torch.add(
self.uvm_cache_stats, self.local_uvm_cache_stats
)
self.local_uvm_cache_stats.zero_()

def _update_cache_miss_counter(
self,
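The block added to prefetch_1way above accumulates the per-iteration int32 counters into a persistent int64 accumulator and then clears them, mirroring what the 32-way path already does. A minimal standalone sketch of that pattern (the function itself is illustrative; the attribute names come from the diff):

import torch


def accumulate_uvm_cache_stats(
    uvm_cache_stats: torch.Tensor,        # int64 running totals (size 6)
    local_uvm_cache_stats: torch.Tensor,  # int32 per-iteration counters (size 6)
) -> torch.Tensor:
    # A non-atomic add is considered acceptable here because the counters are
    # only used for monitoring, as the comment in the diff notes.
    uvm_cache_stats = torch.add(uvm_cache_stats, local_uvm_cache_stats)
    local_uvm_cache_stats.zero_()
    return uvm_cache_stats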
8 changes: 6 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -114,7 +114,9 @@ void direct_mapped_lru_cache_populate_byte_cuda(
int64_t time_stamp,
at::Tensor lru_state,
at::Tensor lxu_cache_miss_timestamp,
int64_t row_alignment);
int64_t row_alignment,
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

///@ingroup table-batched-embed-cuda
/// LFU cache: fetch the rows corresponding to `linear_cache_indices` from
@@ -174,7 +176,9 @@ at::Tensor emulate_cache_miss(
at::Tensor direct_mapped_lxu_cache_lookup_cuda(
at::Tensor linear_cache_indices,
at::Tensor lxu_cache_state,
int64_t invalid_index);
int64_t invalid_index,
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

//////@ingroup table-batched-embed-cuda
/// Flush the cache: store the weights from the cache to the backing storage.
102 changes: 97 additions & 5 deletions fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -11,6 +11,7 @@
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_run_length_encode.cuh>
#include <cub/device/device_select.cuh>
#include <cub/block/block_reduce.cuh>
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
// clang-format on

@@ -728,11 +729,24 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel
lxu_cache_state,
const int64_t time_stamp,
at::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits> lru_state,
const bool gather_cache_stats,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats,
at::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
lxu_cache_miss_timestamp) {
const int32_t N = linear_cache_indices.size(0);
const int32_t C = lxu_cache_state.size(0);
if (gather_cache_stats) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_calls], 1); // N_called.
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_requested_indices],
N); // N_requested_indices.
}
}
CUDA_KERNEL_LOOP(n, N) {
int64_t idx = linear_cache_indices[n];
if (idx == max_indices) {
@@ -879,7 +893,9 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
Tensor lxu_cache_state,
int64_t time_stamp,
Tensor lru_state,
Tensor lxu_cache_miss_timestamp) {
Tensor lxu_cache_miss_timestamp,
bool gather_cache_stats,
Tensor uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
linear_cache_indices,
lxu_cache_state,
@@ -914,6 +930,9 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
time_stamp,
lru_state.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
gather_cache_stats,
uvm_cache_stats
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
lxu_cache_miss_timestamp
.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
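In the find-uncached kernel above, only thread 0 of block 0 performs the two atomic adds, so num_calls advances exactly once per prefetch and num_requested_indices grows by the full length of linear_cache_indices. A tiny host-side reference model of that bookkeeping (illustrative only, not FBGEMM code):

def record_find_uncached_stats(stats: dict, num_requested_indices: int) -> None:
    # One kernel launch counts as one call; every requested index is counted,
    # whether or not it later hits the cache.
    stats["num_calls"] = stats.get("num_calls", 0) + 1
    stats["num_requested_indices"] = (
        stats.get("num_requested_indices", 0) + num_requested_indices
    )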
@@ -1418,6 +1437,9 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
at::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
lxu_cache_miss_timestamp,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> cache_sets,
const bool gather_cache_stats,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats,
const int64_t row_alignment) {
const int32_t N = cache_sets.size(0);
@@ -1445,6 +1467,24 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
// continue;
// }
if (gather_cache_stats && threadIdx.x == 0) {
// We are using this slot for a slightly different purpose.
// In 32 way:
// UVM traffic for insert
// = # of inserted rows
// = # of unique misses - # of unique misses that were not inserted
// = uvm_cache_stats_index::num_unique_misses
// - uvm_cache_stats_index::num_conflict_unique_misses
// In Direct Mapped (here):
// UVM traffic for insert
// = # of inserted rows
// = uvm_cache_stats_index::num_conflict_unique_misses
// (just store here directly)
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_unique_misses],
1);
}
// insert the index in the buffer into our only slot
const int32_t insert_slot = 0;
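The comment above describes how the direct-mapped path repurposes the num_conflict_unique_misses slot: in the 32-way cache the number of inserted rows is derived by subtraction, while here the insert kernel stores it directly. A small reference model of that relationship (purely illustrative):

def inserted_rows_32way(num_unique_misses: int, num_conflict_unique_misses: int) -> int:
    # 32-way: conflict-unique misses are the unique misses that could not be
    # inserted, so the UVM insert traffic is the difference.
    return num_unique_misses - num_conflict_unique_misses


def inserted_rows_direct_mapped(num_conflict_unique_misses: int) -> int:
    # Direct mapped: the insert kernel bumps this slot once per inserted row,
    # so the count is read back as-is.
    return num_conflict_unique_misses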
Expand Down Expand Up @@ -1568,6 +1608,8 @@ void direct_mapped_lru_cache_insert_byte_cuda(
Tensor linear_cache_indices,
Tensor lxu_cache_miss_timestamp,
Tensor cache_sets,
bool gather_cache_stats,
Tensor uvm_cache_stats,
int64_t row_alignment) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
weights,
@@ -1618,6 +1660,9 @@ void direct_mapped_lru_cache_insert_byte_cuda(
lxu_cache_miss_timestamp
.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
cache_sets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
gather_cache_stats,
uvm_cache_stats
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
row_alignment);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@@ -1729,7 +1774,9 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
int64_t time_stamp,
Tensor lru_state,
Tensor lxu_cache_miss_timestamp,
int64_t row_alignment) {
int64_t row_alignment,
bool gather_cache_stats,
c10::optional<Tensor> uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
weights,
cache_hash_size_cumsum,
@@ -1743,6 +1790,13 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
lru_state,
lxu_cache_miss_timestamp);
Tensor uvm_cache_stats_ = at::empty({0}, weights.options().dtype(at::kInt));
if (gather_cache_stats) {
TORCH_CHECK(uvm_cache_stats.has_value());
uvm_cache_stats_ = uvm_cache_stats.value();
TENSOR_ON_CUDA_GPU(uvm_cache_stats_);
}
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(weights.get_device());
@@ -1785,7 +1839,9 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
lxu_cache_state,
time_stamp,
lru_state,
lxu_cache_miss_timestamp);
lxu_cache_miss_timestamp,
gather_cache_stats,
uvm_cache_stats_);
// insert caching weights
direct_mapped_lru_cache_insert_byte_cuda(
@@ -1802,6 +1858,8 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
linear_cache_indices,
lxu_cache_miss_timestamp,
cache_sets,
gather_cache_stats,
uvm_cache_stats_,
row_alignment);
}
@@ -2617,10 +2675,16 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
lxu_cache_state,
int64_t invalid_index,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
lxu_cache_locations) {
lxu_cache_locations,
const bool gather_cache_stats,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats) {
const int32_t C = lxu_cache_state.size(0);
const int32_t N = linear_cache_indices.size(0);
int32_t n_indices = 0;
int32_t n_hits = 0;
CUDA_KERNEL_LOOP(n, N) {
int32_t cache_location = kCacheLocationMissing;
const auto slot = 0;
@@ -2631,13 +2695,29 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
}
const int32_t cache_set = cache_slot(idx, C);
n_indices++;
const bool found =
(::__ldg((&lxu_cache_state[cache_set][0]) + slot) == idx);
if (found) {
cache_location = cache_set;
n_hits++;
}
lxu_cache_locations[n] = cache_location;
}
if (gather_cache_stats) {
typedef cub::BlockReduce<int32_t, kMaxThreads> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp;
int32_t conflict_miss = n_indices - n_hits;
int32_t conflict_miss_sum = BlockReduce(temp).Sum(conflict_miss);
if (threadIdx.x == 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_misses],
conflict_miss_sum);
}
}
}
} // namespace
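The lookup-kernel change above counts, per thread, the valid requested indices that miss their single slot, reduces those counts across the block with cub::BlockReduce, and issues one atomicAdd per block into num_conflict_misses. A sequential Python reference model of the same count (the hash and state layout are simplified stand-ins, not the CUDA implementation):

from typing import List


def count_conflict_misses(
    linear_cache_indices: List[int],
    lxu_cache_state: List[int],   # one resident index per set; -1 means empty
    invalid_index: int = -1,
) -> int:
    C = len(lxu_cache_state)
    n_indices = 0
    n_hits = 0
    for idx in linear_cache_indices:
        if idx == invalid_index:
            continue                  # sentinel indices are not counted
        n_indices += 1
        cache_set = idx % C           # stand-in for cache_slot(idx, C)
        if lxu_cache_state[cache_set] == idx:
            n_hits += 1
    return n_indices - n_hits         # accumulated into num_conflict_misses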
@@ -2748,10 +2828,19 @@ DLL_PUBLIC void lxu_cache_locations_update_cuda(
DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
Tensor linear_cache_indices,
Tensor lxu_cache_state,
int64_t invalid_index) {
int64_t invalid_index,
bool gather_cache_stats,
c10::optional<Tensor> uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
linear_cache_indices, lxu_cache_state);
Tensor uvm_cache_stats_ =
at::empty({0}, linear_cache_indices.options().dtype(at::kInt));
if (gather_cache_stats) {
TORCH_CHECK(uvm_cache_stats.has_value());
uvm_cache_stats_ = uvm_cache_stats.value();
}
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(linear_cache_indices.get_device());
@@ -2780,6 +2869,9 @@ DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
invalid_index,
lxu_cache_locations
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
gather_cache_stats,
uvm_cache_stats_
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
4 changes: 2 additions & 2 deletions fbgemm_gpu/src/split_table_batched_embeddings.cpp
@@ -31,7 +31,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, int row_alignment=16, bool gather_cache_stats=False, Tensor(d!)? uvm_cache_stats=None) -> ()");
DISPATCH_TO_CUDA("lru_cache_populate_byte", lru_cache_populate_byte_cuda);
m.def(
"direct_mapped_lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, Tensor(d!) lxu_cache_miss_timestamp, int row_alignment=16) -> ()");
"direct_mapped_lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, Tensor(d!) lxu_cache_miss_timestamp, int row_alignment=16, bool gather_cache_stats=False, Tensor(e!)? uvm_cache_stats=None) -> ()");
DISPATCH_TO_CUDA(
"direct_mapped_lru_cache_populate_byte",
direct_mapped_lru_cache_populate_byte_cuda);
@@ -45,7 +45,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
DISPATCH_TO_CUDA("lxu_cache_lookup", lxu_cache_lookup_cuda);
m.def(
"direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1) -> Tensor");
"direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
DISPATCH_TO_CUDA(
"direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cuda);
m.def(
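With the updated schema, the direct-mapped lookup op can be called from Python with the new optional stats arguments. A hedged usage sketch, assuming fbgemm_gpu is built and loaded and that the stats tensor is a six-element int32 CUDA tensor as the benchmark changes suggest (shapes and values below are illustrative, not prescribed by the op):

import torch
import fbgemm_gpu  # noqa: F401  (assumed available; loads the fbgemm operators)

linear_cache_indices = torch.randint(0, 1000, (512,), dtype=torch.int64, device="cuda")
lxu_cache_state = torch.full((128, 1), -1, dtype=torch.int64, device="cuda")
uvm_cache_stats = torch.zeros(6, dtype=torch.int32, device="cuda")

lxu_cache_locations = torch.ops.fbgemm.direct_mapped_lxu_cache_lookup(
    linear_cache_indices,
    lxu_cache_state,
    -1,               # invalid_index
    True,             # gather_cache_stats
    uvm_cache_stats,  # conflict misses are accumulated here
)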