From b2c138fd71d426b819033650c8f08633868133c7 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Mon, 18 Dec 2023 12:06:58 -0800 Subject: [PATCH] Optimize the cache fetch for forward split, pt. 1 (#2218) Summary: Rewrite the kernel to use cache_hit_rate enum as template argument. We first check if the cache is empty and pass that value as a template argument. Inside the first kernel, we then determine the cache conflict miss rate, and use this value to as a template parameter when invoking the second kernel, which performs the actual lookup work. We pass in uvm_cache_stats as a run-time argument here instead of passing the cache miss rate as a compile-time argument, because uvm_cache_stats data is only available on the GPU, and incoking a templatized kernel with the cache miss rate as a template argument requires the cache misse information to first be passed back to the host, which is an expensive operation. This is based on the earlier work in stacks D48937380 and D49675672, which have been based on very outdated branches of fbcode. Differential Revision: D51865590 --- .../embedding_backward_split_grad_template.cu | 6 +- ...embedding_backward_split_host_template.cpp | 10 + ...ding_backward_split_kernel_cta_template.cu | 8 +- ...ing_backward_split_kernel_warp_template.cu | 8 +- ...embedding_forward_split_kernel_template.cu | 239 ++++++++++++------ .../embedding_forward_split_meta_template.cpp | 1 + .../embedding_forward_split_template.cu | 18 +- .../embedding_forward_template_helpers.cuh | 15 ++ ...t_table_batched_embeddings_ops_training.py | 6 +- .../include/fbgemm_gpu/fbgemm_cuda_utils.cuh | 30 +++ 10 files changed, 241 insertions(+), 100 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu index 1db787760c..8adb1ea4a7 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu @@ -131,9 +131,9 @@ void split_embedding_backward_count_unique_indices_kernel {% endfor %} {% for vbe in [True, False] %} -{% set vbe_desc = "_vbe" if vbe else "" %} +{% set vdesc = "_vbe" if vbe else "" %} template -__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel( +__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel( pta::PackedTensorAccessor64 grad_output_mean, const pta::PackedTensorAccessor64 @@ -205,7 +205,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel( {% for grad_type in ['at::Half', 'float'] %} template __global__ __launch_bounds__(kMaxThreads) -void grad_mean{{ vbe_desc }}_kernel +void grad_mean{{ vdesc }}_kernel <{{ grad_type }}> ( pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output_mean, diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index ef82da3f84..cef1f8e432 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -53,6 +53,7 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_cuda( const Tensor& indice_weights, {%- endif %} const Tensor& lxu_cache_locations, + const Tensor& uvm_cache_stats, const int64_t output_dtype, {%- if vbe %} const Tensor& vbe_row_output_offsets, @@ -282,6 +283,13 @@ class {{ autograd_func }} : const auto& flatten_dev_weights = dev_weights; {%- endif %} + + + + const auto uvm_cache_stats = at::empty({0}, uvm_weights.options().dtype(at::kInt)); + + + {%- if not nobag %} {%- for weighted in [False, True] %} {%- set wdesc = "weighted" if weighted else "unweighted" %} @@ -316,6 +324,7 @@ class {{ autograd_func }} : *indice_weights, {%- endif %} lxu_cache_locations, + uvm_cache_stats, output_dtype, {%- if vbe %} vbe_row_output_offsets, @@ -346,6 +355,7 @@ class {{ autograd_func }} : indices, offsets, lxu_cache_locations, + uvm_cache_stats, output_dtype, /*is_experimental=*/false ) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu index 1cd572a33c..e53cceb106 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu @@ -8,7 +8,9 @@ // clang-format off {%- set wdesc = "weighted" if weighted else "unweighted" %} -{%- set vbe_desc = "_vbe" if vbe else "" %} +{%- set ndesc = "_nobag" if nobag else "" %} +{%- set vdesc = "_vbe" if vbe else "" %} + #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" @@ -33,7 +35,7 @@ __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_cta_per_row( {%- else %} -split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1( +split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -427,7 +429,7 @@ template __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_cta_per_row {%- else %} -split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1 +split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1 {%- endif %} < {{ emb_type }}, {{ grad_type }}, diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu index 176f80bd9d..a5357f59fc 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu @@ -8,7 +8,9 @@ // clang-format off {%- set wdesc = "weighted" if weighted else "unweighted" %} -{%- set vbe_desc = "_vbe" if vbe else "" %} +{%- set ndesc = "_nobag" if nobag else "" %} +{%- set vdesc = "_vbe" if vbe else "" %} + #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" @@ -33,7 +35,7 @@ __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- else %} -split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1( +split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -270,7 +272,7 @@ template __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- else %} -split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1 +split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 {%- endif %} < {{ emb_type }}, {{ grad_type }}, diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu index 7f93a786ed..ac1c596f7e 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu @@ -15,9 +15,13 @@ // See https://fburl.com/dw9ljh4h #} -{%- set wdesc = "weighted" if weighted else "unweighted" %} -{%- set vbe_desc = "_vbe" if vbe else "" %} +{%- set ddesc = "dense" if dense else "split" %} +{%- set wdesc = "weighted" if weighted else "unweighted" %} +{%- set ndesc = "_nobag" if nobag else "" %} +{%- set vdesc = "_vbe" if vbe else "" %} + #include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -51,13 +55,30 @@ using namespace fbgemm_gpu; */#} {%- macro load_and_accumulate(from_cache) %} {#-/* Set the weights row */#} - const auto weights_row = WeightRow( + const auto weights_row = WeightRowAccessor + < + emb_t, + cache_t, + cache_t, + {%- if from_cache %} + true + {%- else %} + false + {%- endif %} + >( + {%- if from_cache %} + // Pass nullptr to avoid calling &weights[idx_j * D_emb], which loads + // memory into the registers as a side effect + nullptr, + {%- else %} + // Load from the embedding table const_cast(&weights[idx_j * D_emb]), + {%- endif %} {%- if from_cache %} // Load from the cache const_cast(&lxu_cache_weights[cache_idx_j][0]), {%- else %} - // Load from the embedding table + // Pass nullptr bc we are loading from the embedding table nullptr, {%- endif %} D, @@ -110,67 +131,15 @@ using namespace fbgemm_gpu; {%- endmacro %} -template < - typename emb_t, - typename cache_t, - typename output_t, - {%- if not dense %} - bool use_lxu_cache, - {%- endif %} - typename index_t, - {%- if not nobag %} - size_t kMaxVecsPerThread, - {%- endif %} - size_t kThreadGroupSize > -__launch_bounds__(kForwardMaxThreads) __global__ void -{%- if is_index_select %} -batch_index_select_dim0_codegen_forward_kernel( -{%- else %} -{{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel( -{%- endif %} - const pta::PackedTensorAccessor64 dev_weights, - {%- if not dense %} - const pta::PackedTensorAccessor64 uvm_weights, - const pta::PackedTensorAccessor64 lxu_cache_weights, - const pta::PackedTensorAccessor32 weights_placements, - {%- endif %} - const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag or is_index_select %} - const pta::PackedTensorAccessor32 D_offsets, - {%- else %} - int64_t D, - {%- endif %} // if nobag - {%- if vbe %} - const pta::PackedTensorAccessor32 row_output_offsets, - const pta::PackedTensorAccessor32 b_t_map, - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- else %} - FixedDivisor fd_B, - {%- endif %} - const pta::PackedTensorAccessor32 indices, - {%- if not is_index_select %} - const pta::PackedTensorAccessor32 offsets, - {%- endif %} - {%- if not nobag %} - int64_t pooling_mode, - {%- endif %} - {%- if weighted %} - pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, - {%- endif %} - {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, - {%- endif %} - {%- if is_index_select %} - const pta::PackedTensorAccessor32 output_offsets, - const pta::PackedTensorAccessor32 total_L_offsets, - const int32_t fixed_L_per_warp, - const bool permute_output_dim_0_1, - {%- endif %} - // If 2D, shape is [B][total_D] - pta::PackedTensorAccessor64 output - ) { +{#-/* + This code chunk contains the implementation body of the kernel, and is + defined as a Jinja macro to be copy-pasted directly into the kernel as + opposed to a template impl function called by the kernel, because during + benchmarks, it was found that the extra function-calling resulted in a + 20-100 GB/s bandwidth reduction. +*/#} +{%- macro embedding_forward_kernel_impl_body(lxu_miss_rate) %} // shfl_sync_mask is implicitly used by SHFL_SYNC #ifdef FBGEMM_USE_SUBWARP_SHUFFLE const unsigned int shfl_sync_mask = @@ -277,10 +246,13 @@ batch_index_select_dim0_codegen_forward_kernel( for (int32_t l_start = 0; l_start < L; l_start += kThreadGroupSize) { // Determine the L index that this thread will load data from in cooperative load int32_t l = l_start + threadIdx.x; + + {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Cooperatively load the indices - int64_t idx = l < L ? indices[indices_start + l] : 0; + [[maybe_unused]] int64_t idx = l < L ? indices[indices_start + l] : 0; + {%- endif %} - {%- if not dense %} + {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} // Cooperatively load the cache's indices [[maybe_unused]] int32_t cache_idx = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; {%- endif %} @@ -292,8 +264,10 @@ batch_index_select_dim0_codegen_forward_kernel( // Iterate over kThreadGroupSize indices for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group - int64_t idx_j = SHFL_SYNC(idx, j); + [[maybe_unused]] int64_t idx_j = SHFL_SYNC(idx, j); + {%- endif %} {%- if is_index_select %} int64_t output_j = L_start + l_start + j; @@ -301,7 +275,7 @@ batch_index_select_dim0_codegen_forward_kernel( int64_t output_j = indices_start + l_start + j; {%- endif %} - {%- if not dense %} + {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} // Load cache's index from thread j in the group [[maybe_unused]] int32_t cache_idx_j = use_lxu_cache ? SHFL_SYNC(cache_idx, j) : 0; {%- endif %} @@ -314,29 +288,33 @@ batch_index_select_dim0_codegen_forward_kernel( {#/**************************************************************/#} {#-/* - This is the main switch that determines how we are to load and accumulate - weights, and is determined by Jinja-time, compile-time, and run-time - variables. + This is the main switch that determines how we are to load and + accumulate weights, and is determined by Jinja-time, compile-time, + and run-time variables. */#} - {%- if dense %} {#-/* If it's dense, cache is not supported, so load from the embedding table */#} + {%- if dense %} + {#-/* If it's dense, cache is not supported, so load from the embedding table */#} + {{- load_and_accumulate(false) }} + + {%- elif lxu_miss_rate == "cache_conflict_miss_rate::all" %} + {#-/* Else if we know we have a 100% miss rate, then always fetch from the embedding table */#} {{- load_and_accumulate(false) }} - {%- else %} {#-/* Else, cache is supported, so now defer to compile-time selection */#} - if constexpr (use_lxu_cache) { - {#-/* If the row is available in the cache, fetch from the cache */#} + {%- elif lxu_miss_rate == "cache_conflict_miss_rate::zero" %} + {#-/* Else if we know we have a 0% miss rate, then always fetch from the cache */#} + {{- load_and_accumulate(true) }} + + {%- else %} + {#-/* Else we defer to run-time selection */#} if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) { + {#-/* If the row is available in the cache, fetch from the cache */#} {{ load_and_accumulate(true) }} - - {#-/* Else fetch from the embedding table */#} } else { + {#-/* Else fetch from the embedding table */#} {{ load_and_accumulate(false) }} } - } else { - {#-/* If we're not using the LXU cache, fetch from the embedding table */#} - {{- load_and_accumulate(false) }} - } {%- endif %} {#/**************************************************************/#} } @@ -399,6 +377,100 @@ batch_index_select_dim0_codegen_forward_kernel( } {%- endif %} +{%- endmacro %} + + +template < + typename emb_t, + typename cache_t, + typename output_t, + {%- if not dense %} + bool use_lxu_cache, + {%- endif %} + typename index_t, + {%- if not nobag %} + size_t kMaxVecsPerThread, + {%- endif %} + size_t kThreadGroupSize > +__launch_bounds__(kForwardMaxThreads) __global__ void +{%- if is_index_select %} +batch_index_select_dim0_codegen_forward_kernel( +{%- else %} +{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel( +{%- endif %} + const pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + const pta::PackedTensorAccessor64 uvm_weights, + const pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} // if nobag + {%- if vbe %} + const pta::PackedTensorAccessor32 row_output_offsets, + const pta::PackedTensorAccessor32 b_t_map, + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- else %} + FixedDivisor fd_B, + {%- endif %} + const pta::PackedTensorAccessor32 indices, + {%- if not is_index_select %} + const pta::PackedTensorAccessor32 offsets, + {%- endif %} + {%- if not nobag %} + int64_t pooling_mode, + {%- endif %} + {%- if weighted %} + pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32 lxu_cache_locations, + /* + NOTE: We pass in `lxu_cache_conflict_misses = + uvm_cache_stats[uvm_cache_stats_index::num_conflict_unique_misses]` as a + run-time argument here instead of passing the cache miss rate as a + compile-time argument, because `lxu_cache_conflict_misses` is only + available on the GPU, and invoking a templatized kernel with the cache + miss rate as a template argument requires this information to first be + passed back to the host, which is an expensive operation. + */ + const int32_t* lxu_cache_conflict_misses, + {%- endif %} + {%- if is_index_select %} + const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 total_L_offsets, + const int32_t fixed_L_per_warp, + const bool permute_output_dim_0_1, + {%- endif %} + // If 2D, shape is [B][total_D] + pta::PackedTensorAccessor64 output + ) { + {%- if dense %} + {{ embedding_forward_kernel_impl_body("NULL") }} + + {%- else %} + if constexpr (! use_lxu_cache) { + // If use_lxu_cache is false, then the cache conflict miss rate is + // effectively 100% + {{ embedding_forward_kernel_impl_body("cache_conflict_miss_rate::all") }} + + } else { + if (lxu_cache_conflict_misses && *lxu_cache_conflict_misses == 0) { + // If the UVM cache stats tensor is valid and tell us there are no + // conflict unique misses, then the miss rate is effectively 0% + {{ embedding_forward_kernel_impl_body("cache_conflict_miss_rate::zero") }} + + } else { + // Else, the cache conflict miss rate is mixed + {{ embedding_forward_kernel_impl_body("cache_conflict_miss_rate::mixed") }} + } + } + {%- endif %} } @@ -417,7 +489,7 @@ template __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_kernel {%- else %} -{{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel +{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel {%- endif %} < {{ emb_type }}, @@ -428,7 +500,7 @@ batch_index_select_dim0_codegen_forward_kernel {%- endif %} int64_t, {%- if not nobag %} - {{- kMaxVecsPerThread }}, + {{ kMaxVecsPerThread }}, {%- endif %} {{ kThreadGroupSize }} > ( @@ -464,6 +536,7 @@ batch_index_select_dim0_codegen_forward_kernel {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 lxu_cache_locations, + const int32_t* lxu_cache_conflict_misses, {%- endif %} {%- if is_index_select %} const pta::PackedTensorAccessor32 output_offsets, diff --git a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp index f9067f0a88..d78471c476 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp @@ -72,6 +72,7 @@ Tensor {%- endif %} {%- if not dense %} const Tensor& lxu_cache_locations, + const Tensor& uvm_cache_stats, {%- endif %} const int64_t output_dtype, {%- if vbe %} diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu index f572aa75e1..40fa31db99 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu @@ -15,8 +15,8 @@ // See https://fburl.com/dw9ljh4h #} -{%- set ddesc = "dense" if dense else "split" %} -{%- set wdesc = "weighted" if weighted else "unweighted" %} +{%- set ddesc = "dense" if dense else "split" %} +{%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set vdesc = "_vbe" if vbe else "" %} {%- if not dense and not nobag and not vbe %} @@ -30,6 +30,7 @@ //////////////////////////////////////////////////////////////////////////////// {%- endif %} #include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -170,6 +171,7 @@ batch_index_select_dim0_codegen_forward_kernel( {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 lxu_cache_locations, + const int32_t* lxu_cache_conflict_misses, {%- endif %} {%- if is_index_select %} const pta::PackedTensorAccessor32 output_offsets, @@ -248,7 +250,7 @@ batch_index_select_dim0_codegen_forward_kernel( {%- else %} {%- for use_cache in ["false", "true"] %} if (CACHE_CASE_ == {{ use_cache }}) { \ - constexpr auto _TUseCache = {{ use_cache }}; \ + constexpr auto use_cache_t = {{ use_cache }}; \ return __VA_ARGS__(); \ } \ {%- endfor %} @@ -302,6 +304,7 @@ batch_index_select_dim0_codegen_forward_cuda( {%- endif %} {%- if not dense %} const Tensor& lxu_cache_locations, + const Tensor& uvm_cache_stats, {%- endif %} const int64_t output_dtype, {%- if is_index_select %} @@ -547,7 +550,7 @@ batch_index_select_dim0_codegen_forward_cuda( {%- if dense or is_index_select %} {%- else %} - + {%- endif %} <<< div_round_up(total_B, kForwardMaxThreads / kWarpSize), @@ -574,6 +577,7 @@ batch_index_select_dim0_codegen_forward_cuda( {%- endif %} {%- if not dense %} MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), + uvm_cache_stats.size(0) == 0 ? nullptr : (uvm_cache_stats.data_ptr() + uvm_cache_stats_index::num_conflict_unique_misses), {%- endif %} {%- if is_index_select %} MAKE_PTA_WITH_NAME(func_name, output_offsets, int64_t, 1, 32), @@ -598,9 +602,9 @@ batch_index_select_dim0_codegen_forward_cuda( const auto func_name = "{{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel"; #endif {%- if dense %} - {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}_kernel + {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel {%- else %} - {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel + {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel {%- endif %} <<< div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize), @@ -632,6 +636,7 @@ batch_index_select_dim0_codegen_forward_cuda( {%- endif %} {%- if not dense %} MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), + uvm_cache_stats.size(0) == 0 ? nullptr : (uvm_cache_stats.data_ptr() + uvm_cache_stats_index::num_conflict_unique_misses), {%- endif %} // if not dense MAKE_PTA_WITH_NAME(func_name, output, output_t, 2, 64) ); @@ -733,6 +738,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- endif %} {%- if not dense %} " Tensor lxu_cache_locations, " + " Tensor uvm_cache_stats, " {%- endif %} " int output_dtype, " {%- if vbe %} diff --git a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh index 1257011254..20af453f7a 100644 --- a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh @@ -42,6 +42,21 @@ constexpr int32_t kCacheLocationMissing = -1; constexpr size_t kForwardMaxThreads = 512; +namespace fbgemm_gpu { + +enum cache_conflict_miss_rate { + // Cache conflict misses will sometimes occur + mixed = 0, + // Cache conflict misses will always occur, i.e. every weight row to be + // accessed is NOT in the cache + all = 1, + // Cache conflict misses will never occur, i.e. every weight row to be + // accessed IS in the cache + zero = 2, +}; + +} // namespace fbgemm_gpu + namespace nbit { // "Effective" number of elements in the row when we include the row-wise // quantization parameters. diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 36bf5f43f9..75f3ee6c40 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -943,7 +943,9 @@ def forward( # noqa: C901 if len(self.timesteps_prefetched) == 0: self._prefetch(indices, offsets) - self.timesteps_prefetched.pop(0) + if len(self.timesteps_prefetched) > 0: + self.timesteps_prefetched.pop(0) + self.lxu_cache_locations = ( self.lxu_cache_locations_empty if len(self.lxu_cache_locations_list) == 0 @@ -1121,7 +1123,7 @@ def print_uvm_cache_stats(self) -> None: assert ( self.gather_uvm_cache_stats ), "gather_uvm_cache_stats should be set to true to access uvm cache stats." - uvm_cache_stats = self.uvm_cache_stats.tolist() + uvm_cache_stats: List[float] = self.uvm_cache_stats.tolist() logging.info( f"N_called: {uvm_cache_stats[0]}\n" f"N_requested_indices: {uvm_cache_stats[1]}\n" diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index 41fef90327..0361a23a59 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -1457,6 +1457,36 @@ struct WeightRow { } }; +template +struct WeightRowAccessor { + emb_t* row_; + cache_t* cache_row_; + int dim_; + + DEVICE_INLINE WeightRowAccessor( + emb_t* row, + cache_t* cache_row, + int dim, + StochasticRoundingRNGState* stoc_rounding_state) + : row_(row), cache_row_(cache_row), dim_(dim) {} + + DEVICE_INLINE Vec4T load(const int32_t d, const float2 qparams) const { + if constexpr (uses_cache) { + return dequantize_load(cache_row_ + d, qparams); + } else { + return dequantize_load(row_ + d, qparams); + } + } + + DEVICE_INLINE float2 load_qparams() const { + if constexpr (std::is_same_v) { + return load_qparams_from_row(row_ + dim_); + } else { + return make_float2(0.0f, 0.0f); + } + } +}; + __host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) { return (a + b - 1) / b; }