combine cat + reorder for combined indices into one pass #1932

Closed · wants to merge 1 commit
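
A minimal usage sketch of the change (not part of the diff; names and sizes are illustrative and mirror the benchmark below, and it assumes an fbgemm_gpu build containing this commit, imported so that `torch.ops.fbgemm` is registered): the new `cat_reorder_batched_ad_indices` op accepts the per-request index tensors as a list and fuses the `torch.cat` with `reorder_batched_ad_indices`, so the host walks the index data once and can write straight into (optionally pinned) output memory.

```python
import torch
import fbgemm_gpu  # assumed to register the torch.ops.fbgemm operators

# Tiny CPU setup mirroring the benchmark's broadcast_indices=True case.
batch_size, table_size, length, num_ads = 4, 2, 3, 5
ad_lengths = [
    torch.tensor([length] * table_size, dtype=torch.int32) for _ in range(batch_size)
]
ad_indices = [
    torch.randint(0, 100, (table_size * length,), dtype=torch.int64)
    for _ in range(batch_size)
]
batch_offsets = torch.tensor([num_ads * b for b in range(batch_size + 1)], dtype=torch.int32)
num_ads_in_batch = batch_size * num_ads
broadcast_indices = True
total_num_indices = batch_size * table_size * num_ads * length

# Shared length/offset preparation.
cat_ad_lengths = torch.cat(ad_lengths, 0)
reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
    cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)
cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
    reordered_cat_ad_lengths
)

# Unfused path: concatenate first, then reorder.
unfused = torch.ops.fbgemm.reorder_batched_ad_indices(
    cat_ad_offsets,
    torch.cat(ad_indices, 0),
    reordered_cat_ad_offsets,
    batch_offsets,
    num_ads_in_batch,
    broadcast_indices,
    total_num_indices,
)

# Fused path added by this PR: pass the list of tensors directly.
fused = torch.ops.fbgemm.cat_reorder_batched_ad_indices(
    cat_ad_offsets,
    ad_indices,
    reordered_cat_ad_offsets,
    batch_offsets,
    num_ads_in_batch,
    broadcast_indices,
    total_num_indices,
)
torch.testing.assert_close(unfused, fused)
```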
188 changes: 188 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -678,5 +678,193 @@ def batch_group_index_select_bwd(
)


@cli.command()
@click.option("--batch-size", default=8192)
@click.option("--table-size", default=20)
@click.option("--length", default=50)
@click.option("--num-ads", default=100)
@click.option("--dtype", type=click.Choice(["float", "long"]), default="long")
@click.option("--itype", type=click.Choice(["int", "long"]), default="int")
@click.option("--broadcast-indices", type=bool, default=True)
def cat_reorder_batched_ad_indices_bench(
batch_size: int,
table_size: int,
length: int,
num_ads: int,
dtype: str,
itype: str,
broadcast_indices: bool,
) -> None:
assert dtype == "float" or dtype == "long", "Only float and long are supported"
data_type = torch.int64 if dtype == "long" else torch.float
data_size = 8 if dtype == "long" else 4

assert itype == "int" or itype == "long", "Only int and long are supported"

if broadcast_indices:
ad_indices = [
(
torch.randint(
low=0,
high=100,
size=(table_size * length,),
)
.int()
.to(data_type)
)
for _ in range(batch_size)
]
ad_lengths = [
torch.tensor([length for _ in range(table_size)]).int()
for _ in range(batch_size)
]
else:
ad_indices = [
(
torch.randint(
low=0,
high=100,
size=(table_size * num_ads * length,),
)
.int()
.to(data_type)
)
for _ in range(batch_size)
]
ad_lengths = [
torch.tensor([length for _ in range(table_size * num_ads)]).int()
for _ in range(batch_size)
]

batch_offsets = torch.tensor([num_ads * b for b in range(batch_size + 1)]).int()
num_ads_in_batch = batch_size * num_ads

# pyre-ignore
def pass_1(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
cat_ad_lengths = torch.cat(ad_lengths, 0).to("cuda", non_blocking=True)
cat_ad_indices = torch.cat(ad_indices, 0).to("cuda", non_blocking=True)
batch_offsets = batch_offsets.to("cuda", non_blocking=True)
reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)
cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
)
reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
cat_ad_offsets,
cat_ad_indices,
reordered_cat_ad_offsets,
batch_offsets,
num_ads_in_batch,
broadcast_indices,
batch_size * table_size * num_ads * length,
)

return reordered_cat_ad_indices, reordered_cat_ad_lengths

# process lengths on host and reorder indices on device
# pyre-ignore
def pass_2(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
cat_ad_lengths = torch.cat(ad_lengths, 0)

reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)
cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
)
cat_ad_indices = torch.cat(ad_indices, 0)

reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
cat_ad_offsets.to("cuda", non_blocking=True),
cat_ad_indices.to("cuda", non_blocking=True),
reordered_cat_ad_offsets.to("cuda", non_blocking=True),
batch_offsets.to("cuda", non_blocking=True),
num_ads_in_batch,
broadcast_indices,
batch_size * table_size * num_ads * length,
)

return reordered_cat_ad_indices, reordered_cat_ad_lengths.to(
"cuda", non_blocking=True
)

# minimize GPU workload: unfused cat + reorder on host, then copy results to device
# pyre-ignore
def pass_3(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
cat_ad_lengths = torch.cat(ad_lengths, 0)
reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)

cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
)
cat_ad_indices = torch.cat(ad_indices, 0)

reordered_cat_ad_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
cat_ad_offsets,
cat_ad_indices,
reordered_cat_ad_offsets,
batch_offsets,
num_ads_in_batch,
broadcast_indices,
batch_size * table_size * num_ads * length,
)

return reordered_cat_ad_indices.to(
"cuda", non_blocking=True
), reordered_cat_ad_lengths.to("cuda", non_blocking=True)

# minimize GPU workload: fused cat + reorder on host, then copy results to device
# pyre-ignore
def pass_4(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
cat_ad_lengths = torch.cat(ad_lengths, 0)
reordered_cat_ad_lengths = torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_ad_lengths, batch_offsets, num_ads_in_batch, broadcast_indices
)

cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(cat_ad_lengths)
reordered_cat_ad_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_ad_lengths
)

reordered_cat_ad_indices = torch.ops.fbgemm.cat_reorder_batched_ad_indices(
cat_ad_offsets,
ad_indices,
reordered_cat_ad_offsets,
batch_offsets,
num_ads_in_batch,
broadcast_indices,
batch_size * table_size * num_ads * length,
)

return reordered_cat_ad_indices.to(
"cuda", non_blocking=True
), reordered_cat_ad_lengths.to("cuda", non_blocking=True)

num_bytes = batch_size * table_size * (num_ads + 1) * length * data_size

# pyre-ignore
def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
time, _ = benchmark_torch_function(
fn,
(ad_indices, ad_lengths, batch_offsets, num_ads_in_batch),
num_warmups=50,
iters=500,
)
logging.info(
f"{name} fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)"
)

ben(pass_1, "pass_1", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
ben(pass_2, "pass_2", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
ben(pass_3, "pass_3", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)


if __name__ == "__main__":
cli()
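
One way to exercise the new benchmark command programmatically is Click's test runner, sketched below; the import path and option values are assumptions (run from `fbgemm_gpu/bench`), and a CUDA device plus an fbgemm_gpu build containing this change are required.

```python
# Hedged usage sketch, not part of the diff.
from click.testing import CliRunner

# Assumed import path when running from fbgemm_gpu/bench.
from sparse_ops_benchmark import cat_reorder_batched_ad_indices_bench

runner = CliRunner()
result = runner.invoke(
    cat_reorder_batched_ad_indices_bench,
    ["--batch-size", "1024", "--num-ads", "32", "--broadcast-indices", "True"],
)
print(result.exit_code)  # 0 on success; pass_1..pass_4 timings are emitted via logging.info
```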
11 changes: 10 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -352,7 +352,16 @@ at::Tensor reorder_batched_ad_indices_cpu(
const int64_t num_ads_in_batch,
const bool broadcast_indices = false,
const int64_t num_indices_after_broadcast = -1);

///@ingroup sparse-data-cpu
at::Tensor cat_reorder_batched_ad_indices_cpu(
const at::Tensor& cat_ad_offsets,
const std::vector<at::Tensor>& cat_ad_indices,
const at::Tensor& reordered_cat_ad_offsets,
const at::Tensor& batch_offsets,
const int64_t num_ads_in_batch,
const bool broadcast_indices,
const int64_t num_indices_after_broadcast,
const bool pinned_memory = false);
at::Tensor recat_embedding_grad_output_cuda(
at::Tensor grad_output, // [B_local][T_global][D]
const std::vector<int64_t>& num_features_per_rank);
116 changes: 114 additions & 2 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -1276,6 +1276,72 @@ void reorder_batched_ad_indices_cpu_(
});
}

template <typename index_t, typename scalar_t>
void cat_reorder_batched_ad_indices_cpu_(
const Tensor& cat_ad_offsets,
const std::vector<Tensor>& ad_indices,
const Tensor& reordered_cat_ad_offsets,
const Tensor& batch_offsets,
const int64_t num_ads_in_batch,
const bool broadcast_indices,
Tensor& output) {
const int64_t nB = batch_offsets.numel() - 1;
const int64_t nT = (reordered_cat_ad_offsets.numel() - 1) / num_ads_in_batch;

const auto* batch_offsets_data = batch_offsets.data_ptr<int32_t>();
const auto* cat_ad_offsets_data = cat_ad_offsets.data_ptr<index_t>();
const auto* reordered_cat_ad_offsets_data =
reordered_cat_ad_offsets.data_ptr<index_t>();
auto* output_data = output.data_ptr<scalar_t>();
at::parallel_for(
0, nB * nT, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
auto b_begin = tb_begin / nT;
auto b_end = (tb_end + nT - 1) / nT;
for (auto b : c10::irange(b_begin, b_end)) {
const auto* ad_indices_data = ad_indices[b].data_ptr<scalar_t>();
const auto num_ads_b =
batch_offsets_data[b + 1] - batch_offsets_data[b];
int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0;
int64_t t_end =
(b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT;
for (auto t : c10::irange(t_begin, t_end)) {
const auto output_segment_offset_start =
t * num_ads_in_batch + batch_offsets_data[b];
const auto output_segment_start =
reordered_cat_ad_offsets_data[output_segment_offset_start];
const int32_t input_segment_offset_start = broadcast_indices
? nT * b + t
: nT * batch_offsets_data[b] + t * num_ads_b;
const int32_t input_segment_offset_end = broadcast_indices
? input_segment_offset_start + 1
: input_segment_offset_start + num_ads_b;
const auto based_segment = broadcast_indices
? cat_ad_offsets_data[nT * b]
: cat_ad_offsets_data[nT * batch_offsets_data[b]];
const auto input_segment_start =
cat_ad_offsets_data[input_segment_offset_start] - based_segment;
const auto input_segment_end =
cat_ad_offsets_data[input_segment_offset_end] - based_segment;
const auto num_elements = input_segment_end - input_segment_start;
const auto data_size = num_elements * sizeof(scalar_t);
if (broadcast_indices) {
for (auto j : c10::irange(num_ads_b)) {
std::memcpy(
output_data + output_segment_start + j * num_elements,
ad_indices_data + input_segment_start,
data_size);
}
} else {
std::memcpy(
output_data + output_segment_start,
ad_indices_data + input_segment_start,
data_size);
}
}
}
});
}
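
For readers following the index arithmetic above, this is a plain-Python rendering of the same per-(batch, table) segment copy; it is an illustrative reference only (not part of the PR), assuming the offset containers hold Python ints (e.g. via `.tolist()`) and `output` is a preallocated flat buffer of the total output length.

```python
def cat_reorder_reference(cat_ad_offsets, ad_indices, reordered_cat_ad_offsets,
                          batch_offsets, num_ads_in_batch, broadcast_indices, output):
    nB = len(batch_offsets) - 1
    nT = (len(reordered_cat_ad_offsets) - 1) // num_ads_in_batch
    for b in range(nB):
        num_ads_b = batch_offsets[b + 1] - batch_offsets[b]
        for t in range(nT):
            # Where this (table, batch) segment starts in the reordered output.
            out_start = reordered_cat_ad_offsets[t * num_ads_in_batch + batch_offsets[b]]
            if broadcast_indices:
                in_off_start = nT * b + t
                in_off_end = in_off_start + 1
                base = cat_ad_offsets[nT * b]
            else:
                in_off_start = nT * batch_offsets[b] + t * num_ads_b
                in_off_end = in_off_start + num_ads_b
                base = cat_ad_offsets[nT * batch_offsets[b]]
            # Express the input segment relative to this batch's own tensor.
            in_start = cat_ad_offsets[in_off_start] - base
            in_end = cat_ad_offsets[in_off_end] - base
            seg = ad_indices[b][in_start:in_end]
            n = in_end - in_start
            if broadcast_indices:
                # Replicate the single per-table segment once per ad in the batch.
                for j in range(num_ads_b):
                    output[out_start + j * n : out_start + (j + 1) * n] = seg
            else:
                output[out_start : out_start + n] = seg
    return output
```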

Tensor reorder_batched_ad_indices_cpu(
const Tensor& cat_ad_offsets,
const Tensor& cat_ad_indices,
@@ -1320,6 +1386,47 @@ Tensor reorder_batched_ad_indices_cpu(
return reordered_cat_ad_indices;
}

Tensor cat_reorder_batched_ad_indices_cpu(
const Tensor& cat_ad_offsets,
const std::vector<Tensor>& ad_indices,
const Tensor& reordered_cat_ad_offsets,
const Tensor& batch_offsets,
const int64_t num_ads_in_batch,
const bool broadcast_indices,
const int64_t total_num_indices,
const bool pinned_memory) {
TENSOR_ON_CPU(cat_ad_offsets);
for (const auto& t : ad_indices) {
TENSOR_ON_CPU(t);
}
TENSOR_ON_CPU(reordered_cat_ad_offsets);
TENSOR_ON_CPU(batch_offsets);
TORCH_CHECK_GE(total_num_indices, 0);
Tensor reordered_cat_ad_indices = at::empty(
{total_num_indices},
ad_indices[0].options().pinned_memory(pinned_memory));
AT_DISPATCH_INDEX_TYPES(
cat_ad_offsets.scalar_type(),
"cat_reorder_batched_ad_indices_cpu_kernel_1",
[&] {
AT_DISPATCH_ALL_TYPES(
ad_indices[0].scalar_type(),
"cat_reorder_batched_ad_indices_cpu_kernel_2",
[&] {
cat_reorder_batched_ad_indices_cpu_<index_t, scalar_t>(
cat_ad_offsets,
ad_indices,
reordered_cat_ad_offsets,
batch_offsets,
num_ads_in_batch,
broadcast_indices,
reordered_cat_ad_indices);
});
});

return reordered_cat_ad_indices;
}

Tensor offsets_range_cpu(const Tensor& offsets, int64_t range_size) {
TENSOR_ON_CPU(offsets);
TENSOR_NDIM_EQUALS(offsets, 1);
@@ -2126,8 +2233,8 @@ void _permute_lengths_cpu_kernel(
(num_threads + 1) * FALSE_SHARING_PAD, 0);

// First parallel for: populate permuted_lengths, and compute per-thread
// summation of lengths (input_offsets_per_thread_cumsum) and permuted_lengths
// (output_offsets_per_thread_cumsum)
// summation of lengths (input_offsets_per_thread_cumsum) and
// permuted_lengths (output_offsets_per_thread_cumsum)
at::parallel_for(
0, T * B, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
index_t current_input_offset = 0;
@@ -2559,6 +2666,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"reorder_batched_ad_lengths(Tensor cat_ad_lengths, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_lengths=False) -> Tensor");
m.def(
"reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_indices=False, int num_indices_after_broadcast=-1) -> Tensor");
m.def(
"cat_reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor[] cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, int num_ads_in_batch, bool broadcast_indices, int total_num_indices, bool pinned_memory=False) -> Tensor");
m.def("offsets_range(Tensor offsets, int range_size) -> Tensor");
m.def(
"batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor");
@@ -2646,6 +2755,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"reorder_batched_ad_lengths", fbgemm_gpu::reorder_batched_ad_lengths_cpu);
DISPATCH_TO_CPU(
"reorder_batched_ad_indices", fbgemm_gpu::reorder_batched_ad_indices_cpu);
DISPATCH_TO_CPU(
"cat_reorder_batched_ad_indices",
fbgemm_gpu::cat_reorder_batched_ad_indices_cpu);
DISPATCH_TO_CPU("offsets_range", fbgemm_gpu::offsets_range_cpu);
DISPATCH_TO_CPU(
"batched_unary_embeddings",