diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py
index 161bc64242..4ade42d3df 100644
--- a/fbgemm_gpu/bench/sparse_ops_benchmark.py
+++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -679,5 +679,193 @@ def batch_group_index_select_bwd(
     )
 
 
+@cli.command()
+@click.option("--batch-size", default=8192)
+@click.option("--table-size", default=20)
+@click.option("--length", default=50)
+@click.option("--num-ads", default=100)
+@click.option("--dtype", type=click.Choice(["float", "long"]), default="long")
+@click.option("--itype", type=click.Choice(["int", "long"]), default="int")
+@click.option("--broadcast-indices", type=bool, default=True)
+def cat_reorder_batched_ad_indices_bench(
+    batch_size: int,
+    table_size: int,
+    length: int,
+    num_ads: int,
+    dtype: str,
+    itype: str,
+    broadcast_indices: bool,
+) -> None:
+    assert dtype == "float" or dtype == "long", "Only float and long are supported"
+    data_type = torch.int64 if dtype == "long" else torch.float
+    data_size = 8 if dtype == "long" else 4
+
+    assert itype == "int" or itype == "long", "Only int and long are supported"
+
+    if broadcast_indices:
+        ad_indices = [
+            (
+                torch.randint(
+                    low=0,
+                    high=100,
+                    size=(table_size * length,),
+                )
+                .int()
+                .to(data_type)
+            )
+            for _ in range(batch_size)
+        ]
+        ad_lengths = [
+            torch.tensor([length for _ in range(table_size)]).int()
+            for _ in range(batch_size)
+        ]
+    else:
+        ad_indices = [
+            (
+                torch.randint(
+                    low=0,
+                    high=100,
+                    size=(table_size * num_ads * length,),
+                )
+                .int()
+                .to(data_type)
+            )
+            for _ in range(batch_size)
+        ]
+        ad_lengths = [
+            torch.tensor([length for _ in range(table_size * num_ads)]).int()
+            for _ in range(batch_size)
+        ]
+
+    batch_offsets = torch.tensor([num_ads * b for b in range(batch_size + 1)]).int()
+    num_ads_in_batch = batch_size * num_ads
+
+    # pyre-ignore
+    def pass_1(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
+        cat_ad_lengths = torch.cat(ad_lengths, 0).to("cuda", non_blocking=True)
+        cat_ad_indices = torch.cat(ad_indices, 0).to("cuda", non_blocking=True)
+        batch_offsets = batch_offsets.to("cuda", non_blocking=True)
+        reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
+            cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
+        )
+        cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
+        reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            reordered_cat_ad_lengths
+        )
+        reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
+            cat_ad_offsets,
+            cat_ad_indices,
+            reordered_cat_ad_offsets,
+            batch_offsets,
+            num_ads_in_batch,
+            broadcast_indices,
+            batch_size * table_size * num_ads * length,
+        )
+
+        return reordered_cat_ad_indices, reordered_cat_ad_lengths
+
+    # process lengths on host and indices on device
+    # pyre-ignore
+    def pass_2(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
+        cat_ad_lengths = torch.cat(ad_lengths, 0)
+
+        reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
+            cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
+        )
+        cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
+        reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            reordered_cat_ad_lengths
+        )
+        cat_ad_indices = torch.cat(ad_indices, 0)
+
+        reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
+            cat_ad_offsets.to("cuda", non_blocking=True),
+            cat_ad_indices.to("cuda", non_blocking=True),
+            reordered_cat_ad_offsets.to("cuda", non_blocking=True),
+            batch_offsets.to("cuda", non_blocking=True),
+            num_ads_in_batch,
+            broadcast_indices,
+            batch_size * table_size * num_ads * length,
+        )
+
+        return reordered_cat_ad_indices, reordered_cat_ad_lengths.to(
+            "cuda", non_blocking=True
+        )
+
+    # minimize GPU workload + unfused cat + reorder
+    # pyre-ignore
+    def pass_3(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
+        cat_ad_lengths = torch.cat(ad_lengths, 0)
+        reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
+            cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
+        )
+
+        cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
+        reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            reordered_cat_ad_lengths
+        )
+        cat_ad_indices = torch.cat(ad_indices, 0)
+
+        reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
+            cat_ad_offsets,
+            cat_ad_indices,
+            reordered_cat_ad_offsets,
+            batch_offsets,
+            num_ads_in_batch,
+            broadcast_indices,
+            batch_size * table_size * num_ads * length,
+        )
+
+        return reordered_cat_ad_indices.to(
+            "cuda", non_blocking=True
+        ), reordered_cat_ad_lengths.to("cuda", non_blocking=True)
+
+    # minimize GPU workload + fuse cat + reorder
+    # pyre-ignore
+    def pass_4(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
+        cat_ad_lengths = torch.cat(ad_lengths, 0)
+        reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
+            cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
+        )
+
+        cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
+        reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            reordered_cat_ad_lengths
+        )
+
+        reordered_cat_ad_indices = torch.ops.fbgemm.cat_reorder_batched_ad_indices(
+            cat_ad_offsets,
+            ad_indices,
+            reordered_cat_ad_offsets,
+            batch_offsets,
+            num_ads_in_batch,
+            broadcast_indices,
+            batch_size * table_size * num_ads * length,
+        )
+
+        return reordered_cat_ad_indices.to(
+            "cuda", non_blocking=True
+        ), reordered_cat_ad_lengths.to("cuda", non_blocking=True)
+
+    num_bytes = batch_size * table_size * (num_ads + 1) * length * data_size
+
+    # pyre-ignore
+    def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
+        time, _ = benchmark_torch_function(
+            fn,
+            (ad_indices, ad_lengths, batch_offsets, num_ads_in_batch),
+            num_warmups=50,
+            iters=500,
+        )
+        logging.info(
+            f"{name} fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)"
+        )
+
+    ben(pass_1, "pass_1", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
+    ben(pass_2, "pass_2", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
+    ben(pass_3, "pass_3", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
+    ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
+
+
 if __name__ == "__main__":
     cli()
diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 91efc035cc..2c77dbe664 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -352,7 +352,16 @@ at::Tensor reorder_batched_ad_indices_cpu(
     const int64_t num_ads_in_batch,
     const bool broadcast_indices = false,
     const int64_t num_indices_after_broadcast = -1);
-
+///@ingroup sparse-data-cpu
+at::Tensor cat_reorder_batched_ad_indices_cpu(
+    const at::Tensor& cat_ad_offsets,
+    const std::vector<at::Tensor>& cat_ad_indices,
+    const at::Tensor& reordered_cat_ad_offsets,
+    const at::Tensor& batch_offsets,
+    const int64_t num_ads_in_batch,
+    const bool broadcast_indices,
+    const int64_t num_indices_after_broadcast,
+    const bool pinned_memory = false);
 at::Tensor recat_embedding_grad_output_cuda(
     at::Tensor grad_output, // [B_local][T_global][D]
     const std::vector<int64_t>& num_features_per_rank);
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index 6a038c6965..bdc4be63de 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -1276,6 +1276,72 @@ void reorder_batched_ad_indices_cpu_(
       });
 }
 
+template <typename index_t, typename scalar_t>
+void cat_reorder_batched_ad_indices_cpu_(
+    const Tensor& cat_ad_offsets,
+    const std::vector<Tensor>& ad_indices,
+    const Tensor& reordered_cat_ad_offsets,
+    const Tensor& batch_offsets,
+    const int64_t num_ads_in_batch,
+    const bool broadcast_indices,
+    Tensor& output) {
+  const int64_t nB = batch_offsets.numel() - 1;
+  const int64_t nT = (reordered_cat_ad_offsets.numel() - 1) / num_ads_in_batch;
+
+  const auto* batch_offsets_data = batch_offsets.data_ptr<int32_t>();
+  const auto* cat_ad_offsets_data = cat_ad_offsets.data_ptr<index_t>();
+  const auto* reordered_cat_ad_offsets_data =
+      reordered_cat_ad_offsets.data_ptr<index_t>();
+  auto* output_data = output.data_ptr<scalar_t>();
+  at::parallel_for(
+      0, nB * nT, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
+        auto b_begin = tb_begin / nT;
+        auto b_end = (tb_end + nT - 1) / nT;
+        for (auto b : c10::irange(b_begin, b_end)) {
+          const auto* ad_indices_data = ad_indices[b].data_ptr<scalar_t>();
+          const auto num_ads_b =
+              batch_offsets_data[b + 1] - batch_offsets_data[b];
+          int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0;
+          int64_t t_end =
+              (b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT;
+          for (auto t : c10::irange(t_begin, t_end)) {
+            const auto output_segment_offset_start =
+                t * num_ads_in_batch + batch_offsets_data[b];
+            const auto output_segment_start =
+                reordered_cat_ad_offsets_data[output_segment_offset_start];
+            const int32_t input_segment_offset_start = broadcast_indices
+                ? nT * b + t
+                : nT * batch_offsets_data[b] + t * num_ads_b;
+            const int32_t input_segment_offset_end = broadcast_indices
+                ? input_segment_offset_start + 1
+                : input_segment_offset_start + num_ads_b;
+            const auto based_segment = broadcast_indices
+                ? cat_ad_offsets_data[nT * b]
+                : cat_ad_offsets_data[nT * batch_offsets_data[b]];
+            const auto input_segment_start =
+                cat_ad_offsets_data[input_segment_offset_start] - based_segment;
+            const auto input_segment_end =
+                cat_ad_offsets_data[input_segment_offset_end] - based_segment;
+            const auto num_elements = input_segment_end - input_segment_start;
+            const auto data_size = num_elements * sizeof(scalar_t);
+            if (broadcast_indices) {
+              for (auto j : c10::irange(num_ads_b)) {
+                std::memcpy(
+                    output_data + output_segment_start + j * num_elements,
+                    ad_indices_data + input_segment_start,
+                    data_size);
+              }
+            } else {
+              std::memcpy(
+                  output_data + output_segment_start,
+                  ad_indices_data + input_segment_start,
+                  data_size);
+            }
+          }
+        }
+      });
+}
+
 Tensor reorder_batched_ad_indices_cpu(
     const Tensor& cat_ad_offsets,
     const Tensor& cat_ad_indices,
@@ -1319,6 +1385,47 @@ Tensor reorder_batched_ad_indices_cpu(
   return reordered_cat_ad_indices;
 }
 
+Tensor cat_reorder_batched_ad_indices_cpu(
+    const Tensor& cat_ad_offsets,
+    const std::vector<Tensor>& ad_indices,
+    const Tensor& reordered_cat_ad_offsets,
+    const Tensor& batch_offsets,
+    const int64_t num_ads_in_batch,
+    const bool broadcast_indices,
+    const int64_t total_num_indices,
+    const bool pinned_memory) {
+  TENSOR_ON_CPU(cat_ad_offsets);
+  for (const auto& t : ad_indices) {
+    TENSOR_ON_CPU(t);
+  }
+  TENSOR_ON_CPU(reordered_cat_ad_offsets);
+  TENSOR_ON_CPU(batch_offsets);
+  TORCH_CHECK_GE(total_num_indices, 0);
+  Tensor reordered_cat_ad_indices = at::empty(
+      {total_num_indices},
+      ad_indices[0].options().pinned_memory(pinned_memory));
+  AT_DISPATCH_INDEX_TYPES(
+      cat_ad_offsets.scalar_type(),
+      "cat_reorder_batched_ad_indices_cpu_kernel_1",
+      [&] {
+        AT_DISPATCH_ALL_TYPES(
+            ad_indices[0].scalar_type(),
+            "cat_reorder_batched_ad_indices_cpu_kernel_2",
+            [&] {
+              cat_reorder_batched_ad_indices_cpu_<index_t, scalar_t>(
+                  cat_ad_offsets,
+                  ad_indices,
+                  reordered_cat_ad_offsets,
+                  batch_offsets,
+                  num_ads_in_batch,
+                  broadcast_indices,
+                  reordered_cat_ad_indices);
+            });
+      });
+
+  return reordered_cat_ad_indices;
+}
+
 Tensor offsets_range_cpu(const Tensor& offsets, int64_t range_size) {
   TENSOR_ON_CPU(offsets);
   TENSOR_NDIM_EQUALS(offsets, 1);
@@ -2125,8 +2232,8 @@ void _permute_lengths_cpu_kernel(
       (num_threads + 1) * FALSE_SHARING_PAD, 0);
 
   // First parallel for: populate permuted_lengths, and compute per-thread
-  // summation of lengths (input_offsets_per_thread_cumsum) and permuted_lengths
-  // (output_offsets_per_thread_cumsum)
+  // summation of lengths (input_offsets_per_thread_cumsum) and
+  // permuted_lengths (output_offsets_per_thread_cumsum)
   at::parallel_for(
       0, T * B, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
         index_t current_input_offset = 0;
@@ -2558,6 +2665,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "reorder_batched_ad_lengths(Tensor cat_ad_lengths, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_lengths=False) -> Tensor");
   m.def(
       "reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_indices=False, int num_indices_after_broadcast=-1) -> Tensor");
+  m.def(
+      "cat_reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor[] cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_indices, int total_num_indices, bool pinned_memory=False) -> Tensor");
   m.def("offsets_range(Tensor offsets, int range_size) -> Tensor");
   m.def(
       "batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor");
@@ -2645,6 +2754,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
       "reorder_batched_ad_lengths", fbgemm_gpu::reorder_batched_ad_lengths_cpu);
   DISPATCH_TO_CPU(
       "reorder_batched_ad_indices", fbgemm_gpu::reorder_batched_ad_indices_cpu);
+  DISPATCH_TO_CPU(
+      "cat_reorder_batched_ad_indices",
+      fbgemm_gpu::cat_reorder_batched_ad_indices_cpu);
   DISPATCH_TO_CPU("offsets_range", fbgemm_gpu::offsets_range_cpu);
   DISPATCH_TO_CPU(
       "batched_unary_embeddings",
diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py
index 7facb05def..653e7d211b 100644
--- a/fbgemm_gpu/test/sparse_ops_test.py
+++ b/fbgemm_gpu/test/sparse_ops_test.py
@@ -1095,6 +1095,93 @@ def test_reorder_batched_ad_indices(
             else cat_ad_indices.view(B, T, A, L),
         )
 
+    @given(
+        B=st.integers(min_value=1, max_value=20),
+        T=st.integers(min_value=1, max_value=20),
+        L=st.integers(min_value=2, max_value=20),
+        A=st.integers(min_value=1, max_value=20),
+        Dtype=st.sampled_from([torch.int32, torch.float, torch.int64]),
+        Itype=st.sampled_from([torch.int32, torch.int64]),
+        broadcast_indices=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
+    def test_cat_reorder_batched_ad_indices_cpu(
+        self,
+        B: int,
+        T: int,
+        L: int,
+        A: int,
+        Dtype: torch.dtype,
+        Itype: torch.dtype,
+        broadcast_indices: bool,
+    ) -> None:
+        if broadcast_indices:
+            ad_indices = [
+                (
+                    torch.randint(
+                        low=0,
+                        high=100,
+                        size=(T * L,),
+                    )
+                    .int()
+                    .to(Dtype)
+                )
+                for _ in range(B)
+            ]
+            cat_ad_lengths = torch.cat(
+                [torch.tensor([L for _ in range(T)]) for _ in range(B)],
+                0,
+            ).int()
+            cat_ad_lengths_broadcasted = cat_ad_lengths.tile([A])
+            cat_ad_indices = torch.cat(ad_indices, 0)
+        else:
+            ad_indices = [
+                (
+                    torch.randint(
+                        low=0,
+                        high=100,
+                        size=(T * A * L,),
+                    )
+                    .int()
+                    .to(Dtype)
+                )
+                for _ in range(B)
+            ]
+            cat_ad_lengths = torch.cat(
+                [torch.tensor([L for _ in range(T * A)]) for _ in range(B)],
+                0,
+            ).int()
+            cat_ad_lengths_broadcasted = cat_ad_lengths
+            cat_ad_indices = torch.cat(ad_indices, 0)
+        batch_offsets = torch.tensor([A * b for b in range(B + 1)]).int()
+        num_ads_in_batch = B * A
+        reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
+            cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
+        )
+        torch.testing.assert_close(cat_ad_lengths_broadcasted, reordered_cat_ad_lengths)
+
+        cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            cat_ad_lengths
+        ).to(Itype)
+        reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            reordered_cat_ad_lengths
+        ).to(Itype)
+        reordered_cat_ad_indices = torch.ops.fbgemm.cat_reorder_batched_ad_indices(
+            cat_ad_offsets,
+            ad_indices,
+            reordered_cat_ad_offsets,
+            batch_offsets,
+            num_ads_in_batch,
+            broadcast_indices,
+            B * T * A * L,
+        )
+        torch.testing.assert_close(
+            reordered_cat_ad_indices.view(T, B, A, L).permute(1, 0, 2, 3),
+            cat_ad_indices.view(B, T, 1, L).tile([1, 1, A, 1])
+            if broadcast_indices
+            else cat_ad_indices.view(B, T, A, L),
+        )
+
     @given(
         B=st.integers(min_value=1, max_value=20),
         T=st.integers(min_value=1, max_value=20),