Skip to content

Commit

Permalink
[C++][Compute] Refactor rank function implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Jan 7, 2025
1 parent e12bc56 commit 34a9a86
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 82 deletions.
138 changes: 62 additions & 76 deletions cpp/src/arrow/compute/kernels/vector_rank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,45 +28,60 @@ namespace {
// ----------------------------------------------------------------------
// Rank implementation

template <typename ValueSelector,
typename T = std::decay_t<std::invoke_result_t<ValueSelector, int64_t>>>
constexpr uint64_t kDuplicateMask = 1ULL << 63;

template <typename ValueSelector>
void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_selector) {
using T = std::decay_t<decltype(value_selector(uint64_t(0)))>;

// Process non-nulls
if (sorted.non_nulls_end != sorted.non_nulls_begin) {
auto it = sorted.non_nulls_begin;
T prev_value = value_selector(*it);
T curr_value{};
while (++it < sorted.non_nulls_end) {
curr_value = value_selector(*it);
if (curr_value == prev_value) {
// Mark as duplicate
*it |= kDuplicateMask;
}
prev_value = curr_value;
}
}
// Process nulls
if (sorted.nulls_end != sorted.nulls_begin) {
auto it = sorted.nulls_begin;
// Mark all other nulles as duplicate
while (++it < sorted.nulls_end) {
*it |= kDuplicateMask;
}
}
}

bool NeedsDuplicates(RankOptions::Tiebreaker tiebreaker) {
return tiebreaker != RankOptions::First;
}

Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted,
const NullPlacement null_placement,
const RankOptions::Tiebreaker tiebreaker,
ValueSelector&& value_selector) {
const RankOptions::Tiebreaker tiebreaker) {
auto length = sorted.overall_end() - sorted.overall_begin();
ARROW_ASSIGN_OR_RAISE(auto rankings,
MakeMutableUInt64Array(length, ctx->memory_pool()));
auto out_begin = rankings->GetMutableValues<uint64_t>(1);
uint64_t rank;

auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; };
auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; };

switch (tiebreaker) {
case RankOptions::Dense: {
T curr_value, prev_value{};
rank = 0;

if (null_placement == NullPlacement::AtStart && sorted.null_count() > 0) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
curr_value = value_selector(*it);
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
rank++;
}

out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtEnd) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
if (!is_duplicate(*it)) {
++rank;
}
out_begin[original_index(*it)] = rank;
}
break;
}
Expand All @@ -80,62 +95,27 @@ Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted
}

case RankOptions::Min: {
T curr_value, prev_value{};
rank = 0;

if (null_placement == NullPlacement::AtStart) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
curr_value = value_selector(*it);
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
if (!is_duplicate(*it)) {
rank = (it - sorted.overall_begin()) + 1;
}
out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtEnd) {
rank = sorted.non_null_count() + 1;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
out_begin[original_index(*it)] = rank;
}
break;
}

case RankOptions::Max: {
// The algorithm for Max is just like Min, but in reverse order.
T curr_value, prev_value{};
rank = length;

if (null_placement == NullPlacement::AtEnd) {
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_end - 1; it >= sorted.non_nulls_begin; it--) {
curr_value = value_selector(*it);

if (it == sorted.non_nulls_end - 1 || curr_value != prev_value) {
rank = (it - sorted.overall_begin()) + 1;
for (auto it = sorted.overall_end() - 1; it >= sorted.overall_begin(); --it) {
out_begin[original_index(*it)] = rank;
// If the current index isn't marked as duplicate, then it's the last
// tie in a row (since we iterate in reverse order), so update rank
// for the next row of ties.
if (!is_duplicate(*it)) {
rank = it - sorted.overall_begin();
}
out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtStart) {
rank = sorted.null_count();
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

break;
}
}
Expand Down Expand Up @@ -212,8 +192,11 @@ class Ranker<Array> : public RankerMixin<Array, Ranker<Array>> {
auto value_selector = [&array](int64_t index) {
return GetView::LogicalValue(array.GetView(index));
};
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
tiebreaker_, value_selector));
if (NeedsDuplicates(tiebreaker_)) {
MarkDuplicates(sorted, value_selector);
}
ARROW_ASSIGN_OR_RAISE(*output_,
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));

return Status::OK();
}
Expand Down Expand Up @@ -242,8 +225,11 @@ class Ranker<ChunkedArray> : public RankerMixin<ChunkedArray, Ranker<ChunkedArra
auto value_selector = [resolver = ChunkedArrayResolver(span(arrays))](int64_t index) {
return resolver.Resolve(index).Value<InType>();
};
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
tiebreaker_, value_selector));
if (NeedsDuplicates(tiebreaker_)) {
MarkDuplicates(sorted, value_selector);
}
ARROW_ASSIGN_OR_RAISE(*output_,
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));

return Status::OK();
}
Expand Down
8 changes: 2 additions & 6 deletions cpp/src/arrow/compute/kernels/vector_sort_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
#include "arrow/type.h"
#include "arrow/type_traits.h"

namespace arrow {
namespace compute {
namespace internal {
namespace arrow::compute::internal {

// Visit all physical types for which sorting is implemented.
#define VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) \
Expand Down Expand Up @@ -853,6 +851,4 @@ inline Result<std::shared_ptr<ArrayData>> MakeMutableUInt64Array(
return ArrayData::Make(uint64(), length, {nullptr, std::move(data)}, /*null_count=*/0);
}

} // namespace internal
} // namespace compute
} // namespace arrow
} // namespace arrow::compute::internal

0 comments on commit 34a9a86

Please sign in to comment.