Skip to content

Commit

Permalink
parallel reorder_batched_ads_indices/lengths
Browse files Browse the repository at this point in the history
Summary:
add parallelism for reorder batched ads indices/lengths
Meanwhile added a benchmark for reorder_batched_ad_lengths

Differential Revision: https://internalfb.com/D48082346

fbshipit-source-id: b299b3481172916ad05b0c635fe52f7b2e90d16d
  • Loading branch information
xiaoruichao authored and facebook-github-bot committed Aug 11, 2023
1 parent 1b2746f commit 76d093c
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 58 deletions.
90 changes: 79 additions & 11 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def asynchronous_complete_cumsum_2d_ref(lengths: torch.Tensor) -> torch.Tensor:
@click.option("--dtype", type=click.Choice(["float", "long"]), default="long")
@click.option("--itype", type=click.Choice(["int", "long"]), default="int")
@click.option("--broadcast-indices", type=bool, default=True)
@click.option("--device", type=str, default="cpu")
def reorder_batched_ad_indices_bench(
batch_size: int,
table_size: int,
Expand All @@ -373,7 +374,9 @@ def reorder_batched_ad_indices_bench(
dtype: str,
itype: str,
broadcast_indices: bool,
device: str,
) -> None:

assert dtype == "float" or dtype == "long", "Only int and long are supported"
data_type = torch.int64 if dtype == "long" else torch.float
data_size = 8 if dtype == "long" else 4
Expand All @@ -389,7 +392,7 @@ def reorder_batched_ad_indices_bench(
size=(batch_size * table_size * length,),
)
.int()
.cuda()
.to(device)
.to(data_type)
)
cat_ad_lengths = (
Expand All @@ -401,7 +404,7 @@ def reorder_batched_ad_indices_bench(
0,
)
.int()
.cuda()
.to(device)
)
else:
cat_ad_indices = (
Expand All @@ -411,7 +414,7 @@ def reorder_batched_ad_indices_bench(
size=(batch_size * table_size * num_ads * length,),
)
.int()
.cuda()
.to(device)
.to(data_type)
)
cat_ad_lengths = (
Expand All @@ -423,23 +426,27 @@ def reorder_batched_ad_indices_bench(
0,
)
.int()
.cuda()
.to(device)
)

batch_offsets = (
torch.tensor([num_ads * b for b in range(batch_size + 1)]).int().cuda()
)
).to(device)
num_ads_in_batch = batch_size * num_ads
reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)
).to(device)

cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths).to(
index_type
cat_ad_offsets = (
torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
.to(index_type)
.to(device)
)
reordered_cat_ad_offsets = (
torch.ops.fbgemm.asynchronous_complete_cumsum(reordered_cat_ad_lengths)
.to(index_type)
.to(device)
)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
).to(index_type)
time, _ = benchmark_torch_function(
torch.ops.fbgemm.reorder_batched_ad_indices,
(
Expand All @@ -460,6 +467,67 @@ def reorder_batched_ad_indices_bench(
)


@cli.command()
@click.option("--batch-size", default=8192)
@click.option("--table-size", default=20)
@click.option("--length", default=50)
@click.option("--num-ads", default=100)
@click.option("--broadcast-indices", type=bool, default=True)
@click.option("--device", type=str, default="cpu")
def reorder_batched_ad_lengths_bench(
    batch_size: int,
    table_size: int,
    length: int,
    num_ads: int,
    broadcast_indices: bool,
    device: str,
) -> None:
    """Benchmark torch.ops.fbgemm.reorder_batched_ad_lengths.

    Builds a concatenated per-(batch, table) lengths tensor, then times the
    reorder op that regroups lengths from batch-major to table-major layout.

    Args:
        batch_size: number of batches.
        table_size: number of embedding tables per batch.
        length: constant length value filled into every entry.
        num_ads: number of ads per batch.
        broadcast_indices: if True the input holds one entry per
            (batch, table) that the op broadcasts across ads; otherwise the
            input already holds one entry per (batch, table, ad).
        device: torch device string for all tensors (e.g. "cpu" or "cuda").
    """
    if broadcast_indices:
        # One length per (batch, table); the op expands each across num_ads.
        cat_ad_lengths = (
            torch.cat(
                [
                    torch.tensor([length for _ in range(table_size)])
                    for _ in range(batch_size)
                ],
                0,
            )
            .int()
            .to(device)
        )
    else:
        # One length per (batch, table, ad) — already fully materialized.
        cat_ad_lengths = (
            torch.cat(
                [
                    torch.tensor([length for _ in range(table_size * num_ads)])
                    for _ in range(batch_size)
                ],
                0,
            )
            .int()
            .to(device)
        )

    # BUG FIX: the original chained `.cuda()` before `.to(device)`, which
    # fails on CUDA-less hosts even though the default device is "cpu".
    # `.to(device)` alone honors the --device option.
    batch_offsets = (
        torch.tensor([num_ads * b for b in range(batch_size + 1)]).int().to(device)
    )
    num_ads_in_batch = batch_size * num_ads
    time, _ = benchmark_torch_function(
        torch.ops.fbgemm.reorder_batched_ad_lengths,
        (
            cat_ad_lengths,
            batch_offsets,
            num_ads_in_batch,
            broadcast_indices,
        ),
        num_warmups=100,
        iters=1000,
    )
    # Bytes moved per call: input (batch*table*length entries) read once plus
    # output (batch*table*num_ads*length entries) written once, 4 bytes each.
    num_bytes = batch_size * table_size * (num_ads + 1) * length * 4
    logging.info(
        f"fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)"
    )


@cli.command()
@click.option("--num-inputs", default=1024)
@click.option("--rows", default=100)
Expand Down
113 changes: 66 additions & 47 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,21 +1152,30 @@ void reorder_batched_ad_lengths_(
const auto* batch_offsets_data = batch_offsets.data_ptr<index_t>();
const auto* cat_ad_lengths_data = cat_ad_lengths.data_ptr<scalar_t>();
auto* output_data = output.data_ptr<scalar_t>();
for (const auto b : c10::irange(nB)) {
const auto num_ads_b = batch_offsets_data[b + 1] - batch_offsets_data[b];
for (const auto t : c10::irange(nT)) {
const int32_t input_segment_start = broadcast_lengths
? nT * b + t
: nT * batch_offsets_data[b] + t * num_ads_b;
const int32_t output_segment_start =
t * num_ads_in_batch + batch_offsets_data[b];
for (const auto i : c10::irange(num_ads_b)) {
output_data[output_segment_start + i] = broadcast_lengths
? cat_ad_lengths_data[input_segment_start]
: cat_ad_lengths_data[input_segment_start + i];
}
}
}
at::parallel_for(
0, nB * nT, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
auto b_begin = tb_begin / nT;
auto b_end = (tb_end + nT - 1) / nT;
for (const auto b : c10::irange(b_begin, b_end)) {
const auto num_ads_b =
batch_offsets_data[b + 1] - batch_offsets_data[b];
int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0;
int64_t t_end =
(b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT;
for (const auto t : c10::irange(t_begin, t_end)) {
const int32_t input_segment_start = broadcast_lengths
? nT * b + t
: nT * batch_offsets_data[b] + t * num_ads_b;
const int32_t output_segment_start =
t * num_ads_in_batch + batch_offsets_data[b];
for (const auto i : c10::irange(num_ads_b)) {
output_data[output_segment_start + i] = broadcast_lengths
? cat_ad_lengths_data[input_segment_start]
: cat_ad_lengths_data[input_segment_start + i];
}
}
}
});
}

Tensor reorder_batched_ad_lengths_cpu(
Expand Down Expand Up @@ -1221,40 +1230,50 @@ void reorder_batched_ad_indices_cpu_(
reordered_cat_ad_offsets.data_ptr<index_t>();
const auto* cat_ad_indices_data = cat_ad_indices.data_ptr<scalar_t>();
auto* output_data = output.data_ptr<scalar_t>();
for (const auto b : c10::irange(nB)) {
const auto num_ads_b = batch_offsets_data[b + 1] - batch_offsets_data[b];
for (const auto t : c10::irange(nT)) {
const auto output_segment_offset_start =
t * num_ads_in_batch + batch_offsets_data[b];
const auto output_segment_start =
reordered_cat_ad_offsets_data[output_segment_offset_start];
const int32_t input_segment_offset_start = broadcast_indices
? nT * b + t
: nT * batch_offsets_data[b] + t * num_ads_b;
const int32_t input_segment_offset_end = broadcast_indices
? input_segment_offset_start + 1
: input_segment_offset_start + num_ads_b;
const auto input_segment_start =
cat_ad_offsets_data[input_segment_offset_start];
const auto input_segment_end =
cat_ad_offsets_data[input_segment_offset_end];
const auto num_elements = input_segment_end - input_segment_start;

if (broadcast_indices) {
for (auto j : c10::irange(num_ads_b)) {
for (auto i : c10::irange(num_elements)) {
output_data[output_segment_start + j * num_elements + i] =
cat_ad_indices_data[input_segment_start + i];
at::parallel_for(
0, nB * nT, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
auto b_begin = tb_begin / nT;
auto b_end = (tb_end + nT - 1) / nT;

for (const auto b : c10::irange(b_begin, b_end)) {
const auto num_ads_b =
batch_offsets_data[b + 1] - batch_offsets_data[b];
int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0;
int64_t t_end =
(b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT;
for (const auto t : c10::irange(t_begin, t_end)) {
const auto output_segment_offset_start =
t * num_ads_in_batch + batch_offsets_data[b];
const auto output_segment_start =
reordered_cat_ad_offsets_data[output_segment_offset_start];
const int32_t input_segment_offset_start = broadcast_indices
? nT * b + t
: nT * batch_offsets_data[b] + t * num_ads_b;
const int32_t input_segment_offset_end = broadcast_indices
? input_segment_offset_start + 1
: input_segment_offset_start + num_ads_b;
const auto input_segment_start =
cat_ad_offsets_data[input_segment_offset_start];
const auto input_segment_end =
cat_ad_offsets_data[input_segment_offset_end];
const auto num_elements = input_segment_end - input_segment_start;

if (broadcast_indices) {
for (auto j : c10::irange(num_ads_b)) {
for (auto i : c10::irange(num_elements)) {
output_data[output_segment_start + j * num_elements + i] =
cat_ad_indices_data[input_segment_start + i];
}
}
} else {
for (auto i : c10::irange(num_elements)) {
output_data[output_segment_start + i] =
cat_ad_indices_data[input_segment_start + i];
}
}
}
}
} else {
for (auto i : c10::irange(num_elements)) {
output_data[output_segment_start + i] =
cat_ad_indices_data[input_segment_start + i];
}
}
}
}
});
}

Tensor reorder_batched_ad_indices_cpu(
Expand Down

0 comments on commit 76d093c

Please sign in to comment.