Skip to content

Commit

Permalink
Optimize the cache fetch for forward split, pt. 1 (pytorch#2218)
Browse files Browse the repository at this point in the history
Summary:


Rewrite the kernel to use cache_hit_rate enum as template argument.  We first check if the cache is empty and pass that value as a template argument.  Inside the first kernel, we then determine the cache conflict miss rate, and use this value to as a template parameter when invoking the second kernel, which performs the actual lookup work.

We pass in uvm_cache_stats as a run-time argument here instead of passing the cache miss rate as a compile-time argument, because uvm_cache_stats data is only available on the GPU, and incoking a templatized kernel with the cache miss rate as a template argument requires the cache misse information to first be passed back to the host, which is an expensive operation.

This is based on the earlier work in stacks D48937380 and D49675672, which have been based on very outdated branches of fbcode.

Differential Revision: D51865590
  • Loading branch information
q10 authored and facebook-github-bot committed Dec 18, 2023
1 parent a535f22 commit b2c138f
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 100 deletions.
6 changes: 3 additions & 3 deletions fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ void split_embedding_backward_count_unique_indices_kernel
{% endfor %}

{% for vbe in [True, False] %}
{% set vbe_desc = "_vbe" if vbe else "" %}
{% set vdesc = "_vbe" if vbe else "" %}
template <typename grad_t>
__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel(
__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output_mean,
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
Expand Down Expand Up @@ -205,7 +205,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel(

{% for grad_type in ['at::Half', 'float'] %}
template __global__ __launch_bounds__(kMaxThreads)
void grad_mean{{ vbe_desc }}_kernel
void grad_mean{{ vdesc }}_kernel
<{{ grad_type }}> (
pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
grad_output_mean,
Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_cuda(
const Tensor& indice_weights,
{%- endif %}
const Tensor& lxu_cache_locations,
const Tensor& uvm_cache_stats,
const int64_t output_dtype,
{%- if vbe %}
const Tensor& vbe_row_output_offsets,
Expand Down Expand Up @@ -282,6 +283,13 @@ class {{ autograd_func }} :
const auto& flatten_dev_weights = dev_weights;
{%- endif %}




const auto uvm_cache_stats = at::empty({0}, uvm_weights.options().dtype(at::kInt));



{%- if not nobag %}
{%- for weighted in [False, True] %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
Expand Down Expand Up @@ -316,6 +324,7 @@ class {{ autograd_func }} :
*indice_weights,
{%- endif %}
lxu_cache_locations,
uvm_cache_stats,
output_dtype,
{%- if vbe %}
vbe_row_output_offsets,
Expand Down Expand Up @@ -346,6 +355,7 @@ class {{ autograd_func }} :
indices,
offsets,
lxu_cache_locations,
uvm_cache_stats,
output_dtype,
/*is_experimental=*/false
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

// clang-format off
{%- set wdesc = "weighted" if weighted else "unweighted" %}
{%- set vbe_desc = "_vbe" if vbe else "" %}
{%- set ndesc = "_nobag" if nobag else "" %}
{%- set vdesc = "_vbe" if vbe else "" %}

#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
Expand All @@ -33,7 +35,7 @@ __global__ __launch_bounds__(kMaxThreads) void
{%- if is_index_select %}
batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
{%- else %}
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1(
split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1(
{%- endif %}
const pta::PackedTensorAccessor64<grad_t, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output,
{%- if optimizer != "none" %}
Expand Down Expand Up @@ -427,7 +429,7 @@ template __global__ __launch_bounds__(kMaxThreads) void
{%- if is_index_select %}
batch_index_select_dim0_codegen_backward_kernel_cta_per_row
{%- else %}
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1
split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1
{%- endif %}
< {{ emb_type }},
{{ grad_type }},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

// clang-format off
{%- set wdesc = "weighted" if weighted else "unweighted" %}
{%- set vbe_desc = "_vbe" if vbe else "" %}
{%- set ndesc = "_nobag" if nobag else "" %}
{%- set vdesc = "_vbe" if vbe else "" %}

#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
Expand All @@ -33,7 +35,7 @@ __global__ __launch_bounds__(kBackwardMaxThreads) void
{%- if is_index_select %}
batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
{%- else %}
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1(
split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1(
{%- endif %}
const pta::PackedTensorAccessor64<grad_t, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output,
{%- if optimizer != "none" %}
Expand Down Expand Up @@ -270,7 +272,7 @@ template __global__ __launch_bounds__(kBackwardMaxThreads) void
{%- if is_index_select %}
batch_index_select_dim0_codegen_backward_kernel_warp_per_row
{%- else %}
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1
split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1
{%- endif %}
< {{ emb_type }},
{{ grad_type }},
Expand Down
Loading

0 comments on commit b2c138f

Please sign in to comment.