Skip to content

Commit

Permalink
combine cat + reorder for combined indice into one pass (pytorch#1932)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1932

as title. combine cat + reorder into one pass for combined indice to avoid unncessary copy of cat_indices.

Also added a micro benchmark for such case

Differential Revision: D48058732

fbshipit-source-id: 1523f7862bd8bebae41b1dec8a3cee696cfd4aac
  • Loading branch information
garroud authored and facebook-github-bot committed Aug 11, 2023
1 parent f6b1529 commit dbc4bb4
Show file tree
Hide file tree
Showing 4 changed files with 399 additions and 3 deletions.
188 changes: 188 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,5 +679,193 @@ def batch_group_index_select_bwd(
)


@cli.command()
@click.option("--batch-size", default=8192)
@click.option("--table-size", default=20)
@click.option("--length", default=50)
@click.option("--num-ads", default=100)
@click.option("--dtype", type=click.Choice(["float", "long"]), default="long")
@click.option("--itype", type=click.Choice(["int", "long"]), default="int")
@click.option("--broadcast-indices", type=bool, default=True)
def cat_reorder_batched_ad_indices_bench(
batch_size: int,
table_size: int,
length: int,
num_ads: int,
dtype: str,
itype: str,
broadcast_indices: bool,
) -> None:
assert dtype == "float" or dtype == "long", "Only int and long are supported"
data_type = torch.int64 if dtype == "long" else torch.float
data_size = 8 if dtype == "long" else 4

assert itype == "int" or itype == "long", "Only int and long are supported"

if broadcast_indices:
ad_indices = [
(
torch.randint(
low=0,
high=100,
size=(table_size * length,),
)
.int()
.to(data_type)
)
for _ in range(batch_size)
]
ad_lengths = [
torch.tensor([length for _ in range(table_size)]).int()
for _ in range(batch_size)
]
else:
ad_indices = [
(
torch.randint(
low=0,
high=100,
size=(table_size * num_ads * length,),
)
.int()
.to(data_type)
)
for _ in range(batch_size)
]
ad_lengths = [
torch.tensor([length for _ in range(table_size * num_ads)]).int()
for _ in range(batch_size)
]

batch_offsets = torch.tensor([num_ads * b for b in range(batch_size + 1)]).int()
num_ads_in_batch = batch_size * num_ads

# pyre-ignore
def pass_1(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
cat_ad_lengths = torch.cat(ad_lengths, 0).to("cuda", non_blocking=True)
cat_ad_indices = torch.cat(ad_indices, 0).to("cuda", non_blocking=True)
batch_offsets = batch_offsets.to("cuda", non_blocking=True)
reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)
cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
)
reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
cat_ad_offsets,
cat_ad_indices,
reordered_cat_ad_offsets,
batch_offsets,
num_ads_in_batch,
broadcast_indices,
batch_size * table_size * num_ads * length,
)

return reordered_cat_ad_indices, reordered_cat_ad_lengths

# process length on device and process indice on device
# pyre-ignore
def pass_2(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
cat_ad_lengths = torch.cat(ad_lengths, 0)

reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)
cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
)
cat_ad_indices = torch.cat(ad_indices, 0)

reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
cat_ad_offsets.to("cuda", non_blocking=True),
cat_ad_indices.to("cuda", non_blocking=True),
reordered_cat_ad_offsets.to("cuda", non_blocking=True),
batch_offsets.to("cuda", non_blocking=True),
num_ads_in_batch,
broadcast_indices,
batch_size * table_size * num_ads * length,
)

return reordered_cat_ad_indices, reordered_cat_ad_lengths.to(
"cuda", non_blocking=True
)

# minimize GPU workload + unfused cat + reorder
# pyre-ignore
def pass_3(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
cat_ad_lengths = torch.cat(ad_lengths, 0)
reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)

cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
)
cat_ad_indices = torch.cat(ad_indices, 0)

reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
cat_ad_offsets,
cat_ad_indices,
reordered_cat_ad_offsets,
batch_offsets,
num_ads_in_batch,
broadcast_indices,
batch_size * table_size * num_ads * length,
)

return reordered_cat_ad_indices.to(
"cuda", non_blocking=True
), reordered_cat_ad_lengths.to("cuda", non_blocking=True)

# minimize GPU workload + fuse cat + reorder
# pyre-ignore
def pass_4(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
cat_ad_lengths = torch.cat(ad_lengths, 0)
reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)

cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
)

reordered_cat_ad_indices = torch.ops.fbgemm.cat_reorder_batched_ad_indices(
cat_ad_offsets,
ad_indices,
reordered_cat_ad_offsets,
batch_offsets,
num_ads_in_batch,
broadcast_indices,
batch_size * table_size * num_ads * length,
)

return reordered_cat_ad_indices.to(
"cuda", non_blocking=True
), reordered_cat_ad_lengths.to("cuda", non_blocking=True)

num_bytes = batch_size * table_size * (num_ads + 1) * length * data_size

# pyre-ignore
def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
time, _ = benchmark_torch_function(
fn,
(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch),
num_warmups=50,
iters=500,
)
logging.info(
f"{name} fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)"
)

ben(pass_1, "pass_1", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
ben(pass_2, "pass_2", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
ben(pass_3, "pass_3", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)


if __name__ == "__main__":
cli()
11 changes: 10 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,16 @@ at::Tensor reorder_batched_ad_indices_cpu(
const int64_t num_ads_in_batch,
const bool broadcast_indices = false,
const int64_t num_indices_after_broadcast = -1);

///@ingroup sparse-data-cpu
at::Tensor cat_reorder_batched_ad_indices_cpu(
const at::Tensor& cat_ad_offsets,
const std::vector<at::Tensor>& cat_ad_indices,
const at::Tensor& reordered_cat_ad_offsets,
const at::Tensor& batch_offsets,
const int64_t num_ads_in_batch,
const bool broadcast_indices,
const int64_t num_indices_after_broadcast,
const bool pinned_memory = false);
at::Tensor recat_embedding_grad_output_cuda(
at::Tensor grad_output, // [B_local][T_global][D]
const std::vector<int64_t>& num_features_per_rank);
Expand Down
116 changes: 114 additions & 2 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,72 @@ void reorder_batched_ad_indices_cpu_(
});
}

template <typename index_t, typename scalar_t>
void cat_reorder_batched_ad_indices_cpu_(
const Tensor& cat_ad_offsets,
const std::vector<Tensor>& ad_indices,
const Tensor& reordered_cat_ad_offsets,
const Tensor& batch_offsets,
const int64_t num_ads_in_batch,
const bool broadcast_indices,
Tensor& output) {
const int64_t nB = batch_offsets.numel() - 1;
const int64_t nT = (reordered_cat_ad_offsets.numel() - 1) / num_ads_in_batch;

const auto* batch_offsets_data = batch_offsets.data_ptr<int32_t>();
const auto* cat_ad_offsets_data = cat_ad_offsets.data_ptr<index_t>();
const auto* reordered_cat_ad_offsets_data =
reordered_cat_ad_offsets.data_ptr<index_t>();
auto* output_data = output.data_ptr<scalar_t>();
at::parallel_for(
0, nB * nT, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
auto b_begin = tb_begin / nT;
auto b_end = (tb_end + nT - 1) / nT;
for (auto b : c10::irange(b_begin, b_end)) {
const auto* ad_indices_data = ad_indices[b].data_ptr<scalar_t>();
const auto num_ads_b =
batch_offsets_data[b + 1] - batch_offsets_data[b];
int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0;
int64_t t_end =
(b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT;
for (auto t : c10::irange(t_begin, t_end)) {
const auto output_segment_offset_start =
t * num_ads_in_batch + batch_offsets_data[b];
const auto output_segment_start =
reordered_cat_ad_offsets_data[output_segment_offset_start];
const int32_t input_segment_offset_start = broadcast_indices
? nT * b + t
: nT * batch_offsets_data[b] + t * num_ads_b;
const int32_t input_segment_offset_end = broadcast_indices
? input_segment_offset_start + 1
: input_segment_offset_start + num_ads_b;
const auto based_segment = broadcast_indices
? cat_ad_offsets_data[nT * b]
: cat_ad_offsets_data[nT * batch_offsets_data[b]];
const auto input_segment_start =
cat_ad_offsets_data[input_segment_offset_start] - based_segment;
const auto input_segment_end =
cat_ad_offsets_data[input_segment_offset_end] - based_segment;
const auto num_elements = input_segment_end - input_segment_start;
const auto data_size = num_elements * sizeof(scalar_t);
if (broadcast_indices) {
for (auto j : c10::irange(num_ads_b)) {
std::memcpy(
output_data + output_segment_start + j * num_elements,
ad_indices_data + input_segment_start,
data_size);
}
} else {
std::memcpy(
output_data + output_segment_start,
ad_indices_data + input_segment_start,
data_size);
}
}
}
});
}

Tensor reorder_batched_ad_indices_cpu(
const Tensor& cat_ad_offsets,
const Tensor& cat_ad_indices,
Expand Down Expand Up @@ -1319,6 +1385,47 @@ Tensor reorder_batched_ad_indices_cpu(
return reordered_cat_ad_indices;
}

Tensor cat_reorder_batched_ad_indices_cpu(
const Tensor& cat_ad_offsets,
const std::vector<Tensor>& ad_indices,
const Tensor& reordered_cat_ad_offsets,
const Tensor& batch_offsets,
const int64_t num_ads_in_batch,
const bool broadcast_indices,
const int64_t total_num_indices,
const bool pinned_memory) {
TENSOR_ON_CPU(cat_ad_offsets);
for (const auto& t : ad_indices) {
TENSOR_ON_CPU(t);
}
TENSOR_ON_CPU(reordered_cat_ad_offsets);
TENSOR_ON_CPU(batch_offsets);
TORCH_CHECK_GE(total_num_indices, 0);
Tensor reordered_cat_ad_indices = at::empty(
{total_num_indices},
ad_indices[0].options().pinned_memory(pinned_memory));
AT_DISPATCH_INDEX_TYPES(
cat_ad_offsets.scalar_type(),
"cat_reorder_batched_ad_indices_cpu_kernel_1",
[&] {
AT_DISPATCH_ALL_TYPES(
ad_indices[0].scalar_type(),
"cat_reorder_batched_ad_indices_cpu_kernel_2",
[&] {
cat_reorder_batched_ad_indices_cpu_<index_t, scalar_t>(
cat_ad_offsets,
ad_indices,
reordered_cat_ad_offsets,
batch_offsets,
num_ads_in_batch,
broadcast_indices,
reordered_cat_ad_indices);
});
});

return reordered_cat_ad_indices;
}

Tensor offsets_range_cpu(const Tensor& offsets, int64_t range_size) {
TENSOR_ON_CPU(offsets);
TENSOR_NDIM_EQUALS(offsets, 1);
Expand Down Expand Up @@ -2125,8 +2232,8 @@ void _permute_lengths_cpu_kernel(
(num_threads + 1) * FALSE_SHARING_PAD, 0);

// First parallel for: populate permuted_lengths, and compute per-thread
// summation of lengths (input_offsets_per_thread_cumsum) and permuted_lengths
// (output_offsets_per_thread_cumsum)
// summation of lengths (input_offsets_per_thread_cumsum) and
// permuted_lengths (output_offsets_per_thread_cumsum)
at::parallel_for(
0, T * B, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
index_t current_input_offset = 0;
Expand Down Expand Up @@ -2558,6 +2665,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"reorder_batched_ad_lengths(Tensor cat_ad_lengths, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_lengths=False) -> Tensor");
m.def(
"reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_indices=False, int num_indices_after_broadcast=-1) -> Tensor");
m.def(
"cat_reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor[] cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_indices, int total_num_indices, bool pinned_memory=False) -> Tensor");
m.def("offsets_range(Tensor offsets, int range_size) -> Tensor");
m.def(
"batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor");
Expand Down Expand Up @@ -2645,6 +2754,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"reorder_batched_ad_lengths", fbgemm_gpu::reorder_batched_ad_lengths_cpu);
DISPATCH_TO_CPU(
"reorder_batched_ad_indices", fbgemm_gpu::reorder_batched_ad_indices_cpu);
DISPATCH_TO_CPU(
"cat_reorder_batched_ad_indices",
fbgemm_gpu::cat_reorder_batched_ad_indices_cpu);
DISPATCH_TO_CPU("offsets_range", fbgemm_gpu::offsets_range_cpu);
DISPATCH_TO_CPU(
"batched_unary_embeddings",
Expand Down
Loading

0 comments on commit dbc4bb4

Please sign in to comment.