From 76d093c080ac37fc7b619e4d2c85c5eb816e8125 Mon Sep 17 00:00:00 2001
From: xiaoruichao
Date: Fri, 11 Aug 2023 10:48:21 -0700
Subject: [PATCH] Parallelize reorder_batched_ad_indices/lengths

Summary:
Add parallelism to the CPU implementations of reorder_batched_ad_indices and
reorder_batched_ad_lengths. Also add a benchmark for reorder_batched_ad_lengths.

Differential Revision: https://internalfb.com/D48082346

fbshipit-source-id: b299b3481172916ad05b0c635fe52f7b2e90d16d
---
 fbgemm_gpu/bench/sparse_ops_benchmark.py     |  90 +++++++++++++--
 fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 113 +++++++++++--------
 2 files changed, 145 insertions(+), 58 deletions(-)

diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py
index f517e2f661..161bc64242 100644
--- a/fbgemm_gpu/bench/sparse_ops_benchmark.py
+++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -365,6 +365,7 @@ def asynchronous_complete_cumsum_2d_ref(lengths: torch.Tensor) -> torch.Tensor:
 @click.option("--dtype", type=click.Choice(["float", "long"]), default="long")
 @click.option("--itype", type=click.Choice(["int", "long"]), default="int")
 @click.option("--broadcast-indices", type=bool, default=True)
+@click.option("--device", type=str, default="cpu")
 def reorder_batched_ad_indices_bench(
     batch_size: int,
     table_size: int,
@@ -373,7 +374,9 @@ def reorder_batched_ad_indices_bench(
     dtype: str,
     itype: str,
     broadcast_indices: bool,
+    device: str,
 ) -> None:
+    assert dtype == "float" or dtype == "long", "Only float and long are supported"
     data_type = torch.int64 if dtype == "long" else torch.float
     data_size = 8 if dtype == "long" else 4

@@ -389,7 +392,7 @@ def reorder_batched_ad_indices_bench(
                 size=(batch_size * table_size * length,),
             )
             .int()
-            .cuda()
+            .to(device)
             .to(data_type)
         )
         cat_ad_lengths = (
@@ -401,7 +404,7 @@ def reorder_batched_ad_indices_bench(
                 0,
             )
             .int()
-            .cuda()
+            .to(device)
         )
     else:
         cat_ad_indices = (
@@ -411,7 +414,7 @@ def reorder_batched_ad_indices_bench(
                 size=(batch_size * table_size * num_ads * length,),
             )
             .int()
-            .cuda()
+            .to(device)
             .to(data_type)
         )
         cat_ad_lengths = (
@@ -423,23 +426,27 @@ def reorder_batched_ad_indices_bench(
                 0,
             )
             .int()
-            .cuda()
+            .to(device)
         )
     batch_offsets = (
         torch.tensor([num_ads * b for b in range(batch_size + 1)]).int().cuda()
-    )
+    ).to(device)
     num_ads_in_batch = batch_size * num_ads
     reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
         cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
-    )
+    ).to(device)

-    cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths).to(
-        index_type
+    cat_ad_offsets = (
+        torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
+        .to(index_type)
+        .to(device)
+    )
+    reordered_cat_ad_offsets = (
+        torch.ops.fbgemm.asynchronous_complete_cumsum(reordered_cat_ad_lengths)
+        .to(index_type)
+        .to(device)
     )
-    reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
-        reordered_cat_ad_lengths
-    ).to(index_type)
     time, _ = benchmark_torch_function(
         torch.ops.fbgemm.reorder_batched_ad_indices,
         (
@@ -460,6 +467,67 @@ def reorder_batched_ad_indices_bench(
     )


+@cli.command()
+@click.option("--batch-size", default=8192)
+@click.option("--table-size", default=20)
+@click.option("--length", default=50)
+@click.option("--num-ads", default=100)
+@click.option("--broadcast-indices", type=bool, default=True)
+@click.option("--device", type=str, default="cpu")
+def reorder_batched_ad_lengths_bench(
+    batch_size: int,
+    table_size: int,
+    length: int,
+    num_ads: int,
+    broadcast_indices: bool,
+    device: str,
+) 
-> None: + if broadcast_indices: + cat_ad_lengths = ( + torch.cat( + [ + torch.tensor([length for _ in range(table_size)]) + for _ in range(batch_size) + ], + 0, + ) + .int() + .to(device) + ) + else: + cat_ad_lengths = ( + torch.cat( + [ + torch.tensor([length for _ in range(table_size * num_ads)]) + for _ in range(batch_size) + ], + 0, + ) + .int() + .to(device) + ) + + batch_offsets = ( + torch.tensor([num_ads * b for b in range(batch_size + 1)]).int().cuda() + ).to(device) + num_ads_in_batch = batch_size * num_ads + time, _ = benchmark_torch_function( + torch.ops.fbgemm.reorder_batched_ad_lengths, + ( + cat_ad_lengths, + batch_offsets, + num_ads_in_batch, + broadcast_indices, + ), + num_warmups=100, + iters=1000, + ) + num_bytes = batch_size * table_size * (num_ads + 1) * length * 4 + logging.info( + f"fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)" + ) + + @cli.command() @click.option("--num-inputs", default=1024) @click.option("--rows", default=100) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 66b5568f92..6a038c6965 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -1152,21 +1152,30 @@ void reorder_batched_ad_lengths_( const auto* batch_offsets_data = batch_offsets.data_ptr(); const auto* cat_ad_lengths_data = cat_ad_lengths.data_ptr(); auto* output_data = output.data_ptr(); - for (const auto b : c10::irange(nB)) { - const auto num_ads_b = batch_offsets_data[b + 1] - batch_offsets_data[b]; - for (const auto t : c10::irange(nT)) { - const int32_t input_segment_start = broadcast_lengths - ? nT * b + t - : nT * batch_offsets_data[b] + t * num_ads_b; - const int32_t output_segment_start = - t * num_ads_in_batch + batch_offsets_data[b]; - for (const auto i : c10::irange(num_ads_b)) { - output_data[output_segment_start + i] = broadcast_lengths - ? cat_ad_lengths_data[input_segment_start] - : cat_ad_lengths_data[input_segment_start + i]; - } - } - } + at::parallel_for( + 0, nB * nT, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) { + auto b_begin = tb_begin / nT; + auto b_end = (tb_end + nT - 1) / nT; + for (const auto b : c10::irange(b_begin, b_end)) { + const auto num_ads_b = + batch_offsets_data[b + 1] - batch_offsets_data[b]; + int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0; + int64_t t_end = + (b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT; + for (const auto t : c10::irange(t_begin, t_end)) { + const int32_t input_segment_start = broadcast_lengths + ? nT * b + t + : nT * batch_offsets_data[b] + t * num_ads_b; + const int32_t output_segment_start = + t * num_ads_in_batch + batch_offsets_data[b]; + for (const auto i : c10::irange(num_ads_b)) { + output_data[output_segment_start + i] = broadcast_lengths + ? 
cat_ad_lengths_data[input_segment_start] + : cat_ad_lengths_data[input_segment_start + i]; + } + } + } + }); } Tensor reorder_batched_ad_lengths_cpu( @@ -1221,40 +1230,50 @@ void reorder_batched_ad_indices_cpu_( reordered_cat_ad_offsets.data_ptr(); const auto* cat_ad_indices_data = cat_ad_indices.data_ptr(); auto* output_data = output.data_ptr(); - for (const auto b : c10::irange(nB)) { - const auto num_ads_b = batch_offsets_data[b + 1] - batch_offsets_data[b]; - for (const auto t : c10::irange(nT)) { - const auto output_segment_offset_start = - t * num_ads_in_batch + batch_offsets_data[b]; - const auto output_segment_start = - reordered_cat_ad_offsets_data[output_segment_offset_start]; - const int32_t input_segment_offset_start = broadcast_indices - ? nT * b + t - : nT * batch_offsets_data[b] + t * num_ads_b; - const int32_t input_segment_offset_end = broadcast_indices - ? input_segment_offset_start + 1 - : input_segment_offset_start + num_ads_b; - const auto input_segment_start = - cat_ad_offsets_data[input_segment_offset_start]; - const auto input_segment_end = - cat_ad_offsets_data[input_segment_offset_end]; - const auto num_elements = input_segment_end - input_segment_start; - - if (broadcast_indices) { - for (auto j : c10::irange(num_ads_b)) { - for (auto i : c10::irange(num_elements)) { - output_data[output_segment_start + j * num_elements + i] = - cat_ad_indices_data[input_segment_start + i]; + at::parallel_for( + 0, nB * nT, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) { + auto b_begin = tb_begin / nT; + auto b_end = (tb_end + nT - 1) / nT; + + for (const auto b : c10::irange(b_begin, b_end)) { + const auto num_ads_b = + batch_offsets_data[b + 1] - batch_offsets_data[b]; + int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0; + int64_t t_end = + (b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT; + for (const auto t : c10::irange(t_begin, t_end)) { + const auto output_segment_offset_start = + t * num_ads_in_batch + batch_offsets_data[b]; + const auto output_segment_start = + reordered_cat_ad_offsets_data[output_segment_offset_start]; + const int32_t input_segment_offset_start = broadcast_indices + ? nT * b + t + : nT * batch_offsets_data[b] + t * num_ads_b; + const int32_t input_segment_offset_end = broadcast_indices + ? input_segment_offset_start + 1 + : input_segment_offset_start + num_ads_b; + const auto input_segment_start = + cat_ad_offsets_data[input_segment_offset_start]; + const auto input_segment_end = + cat_ad_offsets_data[input_segment_offset_end]; + const auto num_elements = input_segment_end - input_segment_start; + + if (broadcast_indices) { + for (auto j : c10::irange(num_ads_b)) { + for (auto i : c10::irange(num_elements)) { + output_data[output_segment_start + j * num_elements + i] = + cat_ad_indices_data[input_segment_start + i]; + } + } + } else { + for (auto i : c10::irange(num_elements)) { + output_data[output_segment_start + i] = + cat_ad_indices_data[input_segment_start + i]; + } + } } } - } else { - for (auto i : c10::irange(num_elements)) { - output_data[output_segment_start + i] = - cat_ad_indices_data[input_segment_start + i]; - } - } - } - } + }); } Tensor reorder_batched_ad_indices_cpu(
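
Both parallel_for bodies above rely on the same decomposition: at::parallel_for gives each worker a contiguous range [tb_begin, tb_end) over the flattened batch-by-table space of size nB * nT, and the worker recovers which batches it owns (b_begin up to b_end) and, for the partial first and last batch, which sub-range of tables falls inside its range. Iterating the flattened space rather than batches alone keeps the split balanced when nB is small, and the FALSE_SHARING_PAD grain size presumably keeps each worker's slice large enough that different threads do not write neighbouring output elements. The following is a standalone sketch of that index math, not part of the patch: nB, nT, and the fixed chunk size are toy values standing in for the real sizes and for the ranges at::parallel_for would hand out, and the program checks that every (batch, table) pair is visited exactly once.

// Standalone sketch (not part of the patch) of the [tb_begin, tb_end) ->
// (batch, table) decomposition used by the two at::parallel_for bodies above.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t nB = 7; // number of batches (toy value)
  const int64_t nT = 5; // number of tables (toy value)
  const int64_t chunk = 8; // stand-in for one worker's share of [0, nB * nT)

  std::vector<int> visits(nB * nT, 0);
  for (int64_t tb_begin = 0; tb_begin < nB * nT; tb_begin += chunk) {
    const int64_t tb_end = std::min(tb_begin + chunk, nB * nT);
    // Same index math as reorder_batched_ad_lengths_ and
    // reorder_batched_ad_indices_cpu_: find the batches this range touches,
    // then clamp the table range for the partial first and last batch.
    const int64_t b_begin = tb_begin / nT;
    const int64_t b_end = (tb_end + nT - 1) / nT;
    for (int64_t b = b_begin; b < b_end; ++b) {
      const int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0;
      const int64_t t_end =
          (b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT;
      for (int64_t t = t_begin; t < t_end; ++t) {
        ++visits[b * nT + t]; // flat index b * nT + t == tb
      }
    }
  }
  // Every (batch, table) pair is covered exactly once across all chunks, so
  // no output segment is skipped and none is written by two workers.
  for (const int v : visits) {
    assert(v == 1);
  }
  return 0;
}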