diff --git a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu
index 1db787760c..8adb1ea4a7 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu
+++ b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu
@@ -131,9 +131,9 @@ void split_embedding_backward_count_unique_indices_kernel
 {% endfor %}
 
 {% for vbe in [True, False] %}
-{% set vbe_desc = "_vbe" if vbe else "" %}
+{% set vdesc = "_vbe" if vbe else "" %}
 template <typename grad_t>
-__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel(
+__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
     pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output_mean,
     const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
@@ -205,7 +205,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel(
 {% for grad_type in ['at::Half', 'float'] %}
 template __global__ __launch_bounds__(kMaxThreads)
-void grad_mean{{ vbe_desc }}_kernel
+void grad_mean{{ vdesc }}_kernel
 <{{ grad_type }}> (
     pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output_mean,
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
index ef82da3f84..cef1f8e432 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -53,6 +53,7 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_cuda(
     const Tensor& indice_weights,
     {%- endif %}
     const Tensor& lxu_cache_locations,
+    const Tensor& uvm_cache_stats,
     const int64_t output_dtype,
     {%- if vbe %}
     const Tensor& vbe_row_output_offsets,
@@ -282,6 +283,13 @@ class {{ autograd_func }} :
     const auto& flatten_dev_weights = dev_weights;
     {%- endif %}
+
+
+
+    const auto uvm_cache_stats = at::empty({0}, uvm_weights.options().dtype(at::kInt));
+
+
+
     {%- if not nobag %}
     {%- for weighted in [False, True] %}
     {%- set wdesc = "weighted" if weighted else "unweighted" %}
@@ -316,6 +324,7 @@ class {{ autograd_func }} :
                 *indice_weights,
                 {%- endif %}
                 lxu_cache_locations,
+                uvm_cache_stats,
                 output_dtype,
                 {%- if vbe %}
                 vbe_row_output_offsets,
@@ -346,6 +355,7 @@ class {{ autograd_func }} :
                 indices,
                 offsets,
                 lxu_cache_locations,
+                uvm_cache_stats,
                 output_dtype,
                 /*is_experimental=*/false
             )
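Editor's note on the host-side plumbing above: the autograd wrapper materializes `uvm_cache_stats` as an empty tensor, and the launch sites later turn it into either a null pointer or an offset `int32_t*`. The snippet below is a standalone sketch of that "empty tensor as a sentinel" idea only; it assumes libtorch, a made-up `stats_ptr_or_null` helper, and an arbitrary offset, and is not code from this patch.

```cpp
#include <torch/torch.h>
#include <iostream>

// Hypothetical helper: mimics the launch-site expression
//   uvm_cache_stats.size(0) == 0 ? nullptr : uvm_cache_stats.data_ptr<int32_t>() + offset
const int32_t* stats_ptr_or_null(const torch::Tensor& uvm_cache_stats, int64_t offset) {
  // size(0) == 0 marks the sentinel created with at::empty({0}, ...).
  return uvm_cache_stats.size(0) == 0
      ? nullptr
      : uvm_cache_stats.data_ptr<int32_t>() + offset;
}

int main() {
  const auto sentinel = torch::empty({0}, torch::dtype(torch::kInt));
  const auto stats = torch::zeros({6}, torch::dtype(torch::kInt));
  // The real code offsets by uvm_cache_stats_index::num_conflict_unique_misses;
  // 5 here is only an illustrative index.
  std::cout << (stats_ptr_or_null(sentinel, 5) == nullptr) << " "
            << (stats_ptr_or_null(stats, 5) != nullptr) << std::endl;
  return 0;
}
```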
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu
index 1cd572a33c..e53cceb106 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu
+++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu
@@ -8,7 +8,9 @@
 // clang-format off
 {%- set wdesc = "weighted" if weighted else "unweighted" %}
-{%- set vbe_desc = "_vbe" if vbe else "" %}
+{%- set ndesc = "_nobag" if nobag else "" %}
+{%- set vdesc = "_vbe" if vbe else "" %}
+
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
 #include "fbgemm_gpu/fbgemm_tensor_accessor.h"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
@@ -33,7 +35,7 @@ __global__ __launch_bounds__(kMaxThreads) void
 {%- if is_index_select %}
 batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
 {%- else %}
-split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1(
+split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1(
 {%- endif %}
     const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
     {%- if optimizer != "none" %}
@@ -427,7 +429,7 @@ template __global__ __launch_bounds__(kMaxThreads) void
 {%- if is_index_select %}
 batch_index_select_dim0_codegen_backward_kernel_cta_per_row
 {%- else %}
-split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1
+split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1
 {%- endif %}
 < {{ emb_type }},
   {{ grad_type }},
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu
index 176f80bd9d..a5357f59fc 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu
+++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu
@@ -8,7 +8,9 @@
 // clang-format off
 {%- set wdesc = "weighted" if weighted else "unweighted" %}
-{%- set vbe_desc = "_vbe" if vbe else "" %}
+{%- set ndesc = "_nobag" if nobag else "" %}
+{%- set vdesc = "_vbe" if vbe else "" %}
+
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
 #include "fbgemm_gpu/fbgemm_tensor_accessor.h"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
@@ -33,7 +35,7 @@ __global__ __launch_bounds__(kBackwardMaxThreads) void
 {%- if is_index_select %}
 batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
 {%- else %}
-split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1(
+split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1(
 {%- endif %}
     const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
     {%- if optimizer != "none" %}
@@ -270,7 +272,7 @@ template __global__ __launch_bounds__(kBackwardMaxThreads) void
 {%- if is_index_select %}
 batch_index_select_dim0_codegen_backward_kernel_warp_per_row
 {%- else %}
-split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1
+split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1
 {%- endif %}
 < {{ emb_type }},
   {{ grad_type }},
diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu
index 7f93a786ed..ac1c596f7e 100644
--- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu
+++ b/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu
@@ -15,9 +15,13 @@
 // See https://fburl.com/dw9ljh4h
 #}
 
-{%- set wdesc = "weighted" if weighted else "unweighted" %}
-{%- set vbe_desc = "_vbe" if vbe else "" %}
+{%- set ddesc = "dense" if dense else "split" %}
+{%- set wdesc = "weighted" if weighted else "unweighted" %}
+{%- set ndesc = "_nobag" if nobag else "" %}
+{%- set vdesc = "_vbe" if vbe else "" %}
+
 #include "codegen/embedding_forward_template_helpers.cuh"
+#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
 
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
@@ -51,13 +55,30 @@ using namespace fbgemm_gpu;
 */#}
 {%- macro load_and_accumulate(from_cache) %}
     {#-/* Set the weights row */#}
-    const auto weights_row = WeightRow<emb_t, cache_t, cache_t>(
+    const auto weights_row = WeightRowAccessor
+        <
+            emb_t,
+            cache_t,
+            cache_t,
+            {%- if from_cache %}
+            true
+            {%- else %}
+            false
+            {%- endif %}
+        >(
+        {%- if from_cache %}
+        // Pass nullptr to avoid calling &weights[idx_j * D_emb], which loads
+        // memory into the registers as a side effect
+        nullptr,
+        {%- else %}
+        // Load from the embedding table
         const_cast<emb_t*>(&weights[idx_j * D_emb]),
+        {%- endif %}
         {%- if from_cache %}
         // Load from the cache
         const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
         {%- else %}
-        // Load from the embedding table
+        // Pass nullptr because we are loading from the embedding table
         nullptr,
         {%- endif %}
         D,
@@ -110,67 +131,15 @@
-template <
-    typename emb_t,
-    typename cache_t,
-    typename output_t,
-    {%- if not dense %}
-    bool use_lxu_cache,
-    {%- endif %}
-    typename index_t,
-    {%- if not nobag %}
-    size_t kMaxVecsPerThread,
-    {%- endif %}
-    size_t kThreadGroupSize >
-__launch_bounds__(kForwardMaxThreads) __global__ void
-{%- if is_index_select %}
-batch_index_select_dim0_codegen_forward_kernel(
-{%- else %}
-{{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel(
-{%- endif %}
-    const pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
-    {%- if not dense %}
-    const pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
-    const pta::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
-    {%- endif %}
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
-    {%- if not nobag or is_index_select %}
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    {%- else %}
-    int64_t D,
-    {%- endif %} // if nobag
-    {%- if vbe %}
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_output_offsets,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
-    const int32_t info_B_num_bits,
-    const uint32_t info_B_mask,
-    {%- else %}
-    FixedDivisor fd_B,
-    {%- endif %}
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
-    {%- if not is_index_select %}
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
-    {%- endif %}
-    {%- if not nobag %}
-    int64_t pooling_mode,
-    {%- endif %}
-    {%- if weighted %}
-    pta::PackedTensorAccessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits> indice_weights,
-    {%- endif %}
-    {%- if not dense %}
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations,
-    {%- endif %}
-    {%- if is_index_select %}
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> total_L_offsets,
-    const int32_t fixed_L_per_warp,
-    const bool permute_output_dim_0_1,
-    {%- endif %}
-    // If 2D, shape is [B][total_D]
-    pta::PackedTensorAccessor64<output_t, 2, at::RestrictPtrTraits> output
-    ) {
+{#-/*
+    This code chunk contains the implementation body of the kernel, and is
+    defined as a Jinja macro to be copy-pasted directly into the kernel as
+    opposed to a template impl function called by the kernel, because during
+    benchmarks, it was found that the extra function call resulted in a
+    20-100 GB/s bandwidth reduction.
+*/#}
+{%- macro embedding_forward_kernel_impl_body(lxu_miss_rate) %}
     // shfl_sync_mask is implicitly used by SHFL_SYNC
 #ifdef FBGEMM_USE_SUBWARP_SHUFFLE
     const unsigned int shfl_sync_mask =
@@ -277,10 +246,13 @@ batch_index_select_dim0_codegen_forward_kernel(
     for (int32_t l_start = 0; l_start < L; l_start += kThreadGroupSize) {
       // Determine the L index that this thread will load data from in cooperative load
       int32_t l = l_start + threadIdx.x;
+
+      {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
       // Cooperatively load the indices
-      int64_t idx = l < L ? indices[indices_start + l] : 0;
+      [[maybe_unused]] int64_t idx = l < L ? indices[indices_start + l] : 0;
+      {%- endif %}
 
-      {%- if not dense %}
+      {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
       // Cooperatively load the cache's indices
       [[maybe_unused]] int32_t cache_idx = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ?
          lxu_cache_locations[indices_start + l] : 0;
       {%- endif %}
@@ -292,8 +264,10 @@ batch_index_select_dim0_codegen_forward_kernel(
       // Iterate over kThreadGroupSize indices
       for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) {
+        {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
         // Load index from thread j in the group
-        int64_t idx_j = SHFL_SYNC(idx, j);
+        [[maybe_unused]] int64_t idx_j = SHFL_SYNC(idx, j);
+        {%- endif %}
 
         {%- if is_index_select %}
         int64_t output_j = L_start + l_start + j;
@@ -301,7 +275,7 @@ batch_index_select_dim0_codegen_forward_kernel(
         int64_t output_j = indices_start + l_start + j;
         {%- endif %}
 
-        {%- if not dense %}
+        {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
         // Load cache's index from thread j in the group
         [[maybe_unused]] int32_t cache_idx_j = use_lxu_cache ? SHFL_SYNC(cache_idx, j) : 0;
         {%- endif %}
@@ -314,29 +288,33 @@ batch_index_select_dim0_codegen_forward_kernel(
         {#/**************************************************************/#}
         {#-/*
-            This is the main switch that determines how we are to load and accumulate
-            weights, and is determined by Jinja-time, compile-time, and run-time
-            variables.
+            This is the main switch that determines how we are to load and
+            accumulate weights, and is determined by Jinja-time, compile-time,
+            and run-time variables.
         */#}
-        {%- if dense %} {#-/* If it's dense, cache is not supported, so load from the embedding table */#}
+        {%- if dense %}
+        {#-/* If it's dense, cache is not supported, so load from the embedding table */#}
+        {{- load_and_accumulate(false) }}
+
+        {%- elif lxu_miss_rate == "cache_conflict_miss_rate::all" %}
+        {#-/* Else if we know we have a 100% miss rate, then always fetch from the embedding table */#}
         {{- load_and_accumulate(false) }}
 
-        {%- else %} {#-/* Else, cache is supported, so now defer to compile-time selection */#}
-        if constexpr (use_lxu_cache) {
-          {#-/* If the row is available in the cache, fetch from the cache */#}
+        {%- elif lxu_miss_rate == "cache_conflict_miss_rate::zero" %}
+        {#-/* Else if we know we have a 0% miss rate, then always fetch from the cache */#}
+        {{- load_and_accumulate(true) }}
+
+        {%- else %}
+        {#-/* Else we defer to run-time selection */#}
         if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
+          {#-/* If the row is available in the cache, fetch from the cache */#}
           {{ load_and_accumulate(true) }}
-
-          {#-/* Else fetch from the embedding table */#}
         } else {
+          {#-/* Else fetch from the embedding table */#}
           {{ load_and_accumulate(false) }}
         }
-        } else {
-          {#-/* If we're not using the LXU cache, fetch from the embedding table */#}
-          {{- load_and_accumulate(false) }}
-        }
         {%- endif %}
         {#/**************************************************************/#}
       }
@@ -399,6 +377,100 @@ batch_index_select_dim0_codegen_forward_kernel(
     }
     {%- endif %}
+{%- endmacro %}
+
+
+template <
+    typename emb_t,
+    typename cache_t,
+    typename output_t,
+    {%- if not dense %}
+    bool use_lxu_cache,
+    {%- endif %}
+    typename index_t,
+    {%- if not nobag %}
+    size_t kMaxVecsPerThread,
+    {%- endif %}
+    size_t kThreadGroupSize >
+__launch_bounds__(kForwardMaxThreads) __global__ void
+{%- if is_index_select %}
+batch_index_select_dim0_codegen_forward_kernel(
+{%- else %}
+{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel(
+{%- endif %}
+    const pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
+    {%- if not dense %}
+    const pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
+    const pta::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits>
+        lxu_cache_weights,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
+    {%- endif %}
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
+    {%- if not nobag or is_index_select %}
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
+    {%- else %}
+    int64_t D,
+    {%- endif %} // if nobag
+    {%- if vbe %}
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_output_offsets,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
+    const int32_t info_B_num_bits,
+    const uint32_t info_B_mask,
+    {%- else %}
+    FixedDivisor fd_B,
+    {%- endif %}
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
+    {%- if not is_index_select %}
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
+    {%- endif %}
+    {%- if not nobag %}
+    int64_t pooling_mode,
+    {%- endif %}
+    {%- if weighted %}
+    pta::PackedTensorAccessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits> indice_weights,
+    {%- endif %}
+    {%- if not dense %}
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations,
+    /*
+      NOTE: We pass in `lxu_cache_conflict_misses =
+      uvm_cache_stats[uvm_cache_stats_index::num_conflict_unique_misses]` as a
+      run-time argument here instead of passing the cache miss rate as a
+      compile-time argument, because `lxu_cache_conflict_misses` is only
+      available on the GPU, and invoking a templatized kernel with the cache
+      miss rate as a template argument requires this information to first be
+      passed back to the host, which is an expensive operation.
+    */
+    const int32_t* lxu_cache_conflict_misses,
+    {%- endif %}
+    {%- if is_index_select %}
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> total_L_offsets,
+    const int32_t fixed_L_per_warp,
+    const bool permute_output_dim_0_1,
+    {%- endif %}
+    // If 2D, shape is [B][total_D]
+    pta::PackedTensorAccessor64<output_t, 2, at::RestrictPtrTraits> output
+    ) {
+  {%- if dense %}
+  {{ embedding_forward_kernel_impl_body("NULL") }}
+
+  {%- else %}
+  if constexpr (!use_lxu_cache) {
+    // If use_lxu_cache is false, then the cache conflict miss rate is
+    // effectively 100%
+    {{ embedding_forward_kernel_impl_body("cache_conflict_miss_rate::all") }}
+
+  } else {
+    if (lxu_cache_conflict_misses && *lxu_cache_conflict_misses == 0) {
+      // If the UVM cache stats tensor is valid and tells us that there are no
+      // conflict unique misses, then the miss rate is effectively 0%
+      {{ embedding_forward_kernel_impl_body("cache_conflict_miss_rate::zero") }}
+
+    } else {
+      // Else, the cache conflict miss rate is mixed
+      {{ embedding_forward_kernel_impl_body("cache_conflict_miss_rate::mixed") }}
+    }
+  }
+  {%- endif %}
 }
@@ -417,7 +489,7 @@ template __launch_bounds__(kForwardMaxThreads) __global__ void
 {%- if is_index_select %}
 batch_index_select_dim0_codegen_forward_kernel
 {%- else %}
-{{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel
+{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel
 {%- endif %}
 < {{ emb_type }},
@@ -428,7 +500,7 @@ batch_index_select_dim0_codegen_forward_kernel
   {%- endif %}
   int64_t,
   {%- if not nobag %}
-  {{- kMaxVecsPerThread }},
+  {{ kMaxVecsPerThread }},
   {%- endif %}
   {{ kThreadGroupSize }} > (
@@ -464,6 +536,7 @@ batch_index_select_dim0_codegen_forward_kernel
     {%- endif %}
     {%- if not dense %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations,
+    const int32_t* lxu_cache_conflict_misses,
    {%- endif %}
    {%- if is_index_select %}
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
diff --git a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
index f9067f0a88..d78471c476 100644
--- a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
+++ b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
@@ -72,6 +72,7 @@ Tensor
     {%- endif %}
     {%- if not dense %}
     const Tensor& lxu_cache_locations,
+    const Tensor& uvm_cache_stats,
     {%- endif %}
     const int64_t output_dtype,
     {%- if vbe %}
diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu
index f572aa75e1..40fa31db99 100644
--- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu
+++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -15,8 +15,8 @@
 // See https://fburl.com/dw9ljh4h
 #}
 
-{%- set ddesc = "dense" if dense else "split" %}
-{%- set wdesc = "weighted" if weighted else "unweighted" %}
+{%- set ddesc = "dense" if dense else "split" %}
+{%- set wdesc = "weighted" if weighted else "unweighted" %}
 {%- set vdesc = "_vbe" if vbe else "" %}
 
 {%- if not dense and not nobag and not vbe %}
@@ -30,6 +30,7 @@
 ////////////////////////////////////////////////////////////////////////////////
 {%- endif %}
 #include "codegen/embedding_forward_template_helpers.cuh"
+#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
 
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
@@ -170,6 +171,7 @@ batch_index_select_dim0_codegen_forward_kernel(
     {%- endif %}
     {%- if not dense %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations,
+    const int32_t* lxu_cache_conflict_misses,
     {%- endif %}
     {%- if is_index_select %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
@@ -248,7 +250,7 @@ batch_index_select_dim0_codegen_forward_kernel(
     {%- else %}
     {%- for use_cache in ["false", "true"] %}
     if (CACHE_CASE_ == {{ use_cache }}) {              \
-      constexpr auto _TUseCache = {{ use_cache }};     \
+      constexpr auto use_cache_t = {{ use_cache }};    \
      return __VA_ARGS__();                             \
    }                                                   \
    {%- endfor %}
@@ -302,6 +304,7 @@ batch_index_select_dim0_codegen_forward_cuda(
     {%- endif %}
     {%- if not dense %}
     const Tensor& lxu_cache_locations,
+    const Tensor& uvm_cache_stats,
     {%- endif %}
     const int64_t output_dtype,
     {%- if is_index_select %}
@@ -547,7 +550,7 @@ batch_index_select_dim0_codegen_forward_cuda(
         {%- if dense or is_index_select %}
         {%- else %}
-
+
         {%- endif %}
         <<<
             div_round_up(total_B, kForwardMaxThreads / kWarpSize),
@@ -574,6 +577,7 @@ batch_index_select_dim0_codegen_forward_cuda(
             {%- endif %}
             {%- if not dense %}
             MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
+            uvm_cache_stats.size(0) == 0 ? nullptr : (uvm_cache_stats.data_ptr<int32_t>() + uvm_cache_stats_index::num_conflict_unique_misses),
             {%- endif %}
             {%- if is_index_select %}
             MAKE_PTA_WITH_NAME(func_name, output_offsets, int64_t, 1, 32),
@@ -598,9 +602,9 @@ batch_index_select_dim0_codegen_forward_cuda(
         const auto func_name = "{{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel";
 #endif
         {%- if dense %}
-        {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}_kernel
+        {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel
         {%- else %}
-        {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel
+        {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel
         {%- endif %}
         <<<
             div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
@@ -632,6 +636,7 @@ batch_index_select_dim0_codegen_forward_cuda(
             {%- endif %}
             {%- if not dense %}
             MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
+            uvm_cache_stats.size(0) == 0 ? nullptr : (uvm_cache_stats.data_ptr<int32_t>() + uvm_cache_stats_index::num_conflict_unique_misses),
             {%- endif %} // if not dense
             MAKE_PTA_WITH_NAME(func_name, output, output_t, 2, 64)
         );
@@ -733,6 +738,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     {%- endif %}
     {%- if not dense %}
     "    Tensor lxu_cache_locations, "
+    "    Tensor uvm_cache_stats, "
     {%- endif %}
     "    int output_dtype, "
     {%- if vbe %}
diff --git a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh
index 1257011254..20af453f7a 100644
--- a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh
+++ b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh
@@ -42,6 +42,21 @@ constexpr int32_t kCacheLocationMissing = -1;
 constexpr size_t kForwardMaxThreads = 512;
 
+namespace fbgemm_gpu {
+
+enum cache_conflict_miss_rate {
+  // Cache conflict misses will sometimes occur
+  mixed = 0,
+  // Cache conflict misses will always occur, i.e. every weight row to be
+  // accessed is NOT in the cache
+  all = 1,
+  // Cache conflict misses will never occur, i.e. every weight row to be
+  // accessed IS in the cache
+  zero = 2,
+};
+
+} // namespace fbgemm_gpu
+
 namespace nbit {
 // "Effective" number of elements in the row when we include the row-wise
 // quantization parameters.
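Editor's note on the forward-kernel dispatch and the `cache_conflict_miss_rate` enum above: the sketch below is a hypothetical, self-contained CUDA toy (none of these names come from FBGEMM) showing the same idea, i.e. a device-resident conflict-miss counter is read inside the kernel to pick a specialized load path, so the stats never have to be copied back to the host before launch. When the counter is zero, every lookup can skip the per-row miss branch entirely, which is where the bandwidth win comes from.

```cpp
#include <cuda_runtime.h>
#include <cstdio>

enum cache_conflict_miss_rate { mixed = 0, all = 1, zero = 2 };

template <cache_conflict_miss_rate kMissRate>
__device__ float load_row(
    const float* table, const float* cache, int idx, int cache_loc) {
  if constexpr (kMissRate == cache_conflict_miss_rate::zero) {
    return cache[cache_loc];  // 0% miss rate: the cached copy is always valid
  } else {
    // Mixed miss rate: fall back to the backing table on a miss (-1)
    return cache_loc >= 0 ? cache[cache_loc] : table[idx];
  }
}

__global__ void gather_kernel(
    const float* table, const float* cache, const int* idx, const int* cache_loc,
    const int* conflict_misses,  // may be nullptr if stats were not gathered
    float* out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  // Run-time selection between compile-time-specialized load paths
  if (conflict_misses != nullptr && *conflict_misses == 0) {
    out[i] = load_row<cache_conflict_miss_rate::zero>(table, cache, idx[i], cache_loc[i]);
  } else {
    out[i] = load_row<cache_conflict_miss_rate::mixed>(table, cache, idx[i], cache_loc[i]);
  }
}

int main() {
  constexpr int n = 4;
  float *table, *cache, *out;
  int *idx, *cache_loc, *misses;
  cudaMallocManaged(&table, n * sizeof(float));
  cudaMallocManaged(&cache, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  cudaMallocManaged(&idx, n * sizeof(int));
  cudaMallocManaged(&cache_loc, n * sizeof(int));
  cudaMallocManaged(&misses, sizeof(int));
  for (int i = 0; i < n; ++i) {
    table[i] = 1.0f;
    cache[i] = 2.0f;
    idx[i] = i;
    cache_loc[i] = (i % 2 == 0) ? i : -1;  // every other row misses
  }
  *misses = 2;  // non-zero: take the mixed path
  gather_kernel<<<1, n>>>(table, cache, idx, cache_loc, misses, out, n);
  cudaDeviceSynchronize();
  for (int i = 0; i < n; ++i) printf("%.1f ", out[i]);
  printf("\n");
  return 0;
}
```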
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 36bf5f43f9..75f3ee6c40 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -943,7 +943,9 @@ def forward(  # noqa: C901
         if len(self.timesteps_prefetched) == 0:
             self._prefetch(indices, offsets)
 
-        self.timesteps_prefetched.pop(0)
+        if len(self.timesteps_prefetched) > 0:
+            self.timesteps_prefetched.pop(0)
+
         self.lxu_cache_locations = (
             self.lxu_cache_locations_empty
             if len(self.lxu_cache_locations_list) == 0
@@ -1121,7 +1123,7 @@ def print_uvm_cache_stats(self) -> None:
         assert (
             self.gather_uvm_cache_stats
         ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
-        uvm_cache_stats = self.uvm_cache_stats.tolist()
+        uvm_cache_stats: List[float] = self.uvm_cache_stats.tolist()
         logging.info(
             f"N_called: {uvm_cache_stats[0]}\n"
             f"N_requested_indices: {uvm_cache_stats[1]}\n"
diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
index 41fef90327..0361a23a59 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -1457,6 +1457,36 @@ struct WeightRow {
   }
 };
 
+template <typename emb_t, typename cache_t, typename dst_t, bool uses_cache>
+struct WeightRowAccessor {
+  emb_t* row_;
+  cache_t* cache_row_;
+  int dim_;
+
+  DEVICE_INLINE WeightRowAccessor(
+      emb_t* row,
+      cache_t* cache_row,
+      int dim,
+      StochasticRoundingRNGState* stoc_rounding_state)
+      : row_(row), cache_row_(cache_row), dim_(dim) {}
+
+  DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const {
+    if constexpr (uses_cache) {
+      return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams);
+    } else {
+      return dequantize_load<dst_t, emb_t>(row_ + d, qparams);
+    }
+  }
+
+  DEVICE_INLINE float2 load_qparams() const {
+    if constexpr (std::is_same_v<emb_t, uint8_t>) {
+      return load_qparams_from_row(row_ + dim_);
+    } else {
+      return make_float2(0.0f, 0.0f);
+    }
+  }
+};
+
 __host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) {
   return (a + b - 1) / b;
 }
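Editor's note on `WeightRowAccessor`: the toy below is a minimal host-side C++ sketch (a hypothetical `RowAccessor`, not the FBGEMM struct) of the same compile-time selection. Because `uses_cache` is a template parameter, the unused pointer member is never dereferenced, so the caller can pass nullptr for it; this is what lets the forward template avoid the register-loading side effect mentioned in the `load_and_accumulate` macro.

```cpp
#include <cstdio>

template <typename emb_t, typename cache_t, bool uses_cache>
struct RowAccessor {
  const emb_t* row_;          // backing embedding row (may be nullptr when uses_cache)
  const cache_t* cache_row_;  // cached row (may be nullptr when !uses_cache)

  float load(int d) const {
    if constexpr (uses_cache) {
      return static_cast<float>(cache_row_[d]);
    } else {
      return static_cast<float>(row_[d]);
    }
  }
};

int main() {
  const float table_row[4] = {1.f, 2.f, 3.f, 4.f};
  const float cache_row[4] = {10.f, 20.f, 30.f, 40.f};

  // Read through the cache; the table pointer is intentionally nullptr.
  RowAccessor<float, float, true> from_cache{nullptr, cache_row};
  // Read from the table; the cache pointer is intentionally nullptr.
  RowAccessor<float, float, false> from_table{table_row, nullptr};

  std::printf("%f %f\n", from_cache.load(2), from_table.load(2));
  return 0;
}
```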