Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Jan 9, 2025
1 parent 2b5f56c commit 0219c1e
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 158 deletions.
155 changes: 71 additions & 84 deletions cpp/src/arrow/compute/kernels/vector_rank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,114 +28,95 @@ namespace {
// ----------------------------------------------------------------------
// Rank implementation

template <typename ValueSelector,
typename T = std::decay_t<std::invoke_result_t<ValueSelector, int64_t>>>
// Top bit of a sorted index. It is OR-ed into an index to flag that entry as
// a tie with the entry just before it in sorted order; masking it off again
// recovers the original index.
// NOTE(review): this relies on real indices never using bit 63 -- confirm.
constexpr uint64_t kDuplicateMask = uint64_t{1} << 63;

// Tells whether ties have to be flagged in the sorted indices before ranks
// are computed for `tiebreaker`. Only RankOptions::First answers false; every
// other tiebreaker answers true.
bool NeedsDuplicates(RankOptions::Tiebreaker tiebreaker) {
  const bool is_first = (tiebreaker == RankOptions::First);
  return !is_first;
}

// Flag ties in an already-sorted, null-partitioned range of indices.
//
// For the non-null partition, each index whose value (as returned by
// `value_selector(index)`) compares equal to the value of the index just
// before it gets kDuplicateMask OR-ed into it. For the null partition, every
// index except the first one is flagged, i.e. all nulls form a single tie.
//
// The indices that `sorted` points to are modified in place; the first entry
// of each partition is never flagged.
//
// `value_selector` must be callable with an index and return a value that is
// equality-comparable and copy-assignable.
template <typename ValueSelector>
void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_selector) {
  // Process non-nulls
  if (sorted.non_nulls_end != sorted.non_nulls_begin) {
    auto it = sorted.non_nulls_begin;
    // Deduce with `auto` rather than `decltype(value_selector(...))`: `auto`
    // always yields a plain value type, so a selector that returns a
    // reference cannot turn `prev_value` into a reference (which would make
    // the assignment below write through to the underlying element, or fail
    // to compile for a const reference).
    auto prev_value = value_selector(*it);
    while (++it < sorted.non_nulls_end) {
      // The value is read before the index is (possibly) tagged below, so the
      // selector always sees an untagged index.
      auto curr_value = value_selector(*it);
      if (curr_value == prev_value) {
        *it |= kDuplicateMask;
      }
      prev_value = curr_value;
    }
  }

  // Process nulls
  if (sorted.nulls_end != sorted.nulls_begin) {
    // TODO this should be able to distinguish between NaNs and real nulls (GH-45193)
    auto it = sorted.nulls_begin;
    // Skip the first null, then flag all the remaining ones as ties with it.
    while (++it < sorted.nulls_end) {
      *it |= kDuplicateMask;
    }
  }
}

Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted,
const NullPlacement null_placement,
const RankOptions::Tiebreaker tiebreaker,
ValueSelector&& value_selector) {
const RankOptions::Tiebreaker tiebreaker) {
auto length = sorted.overall_end() - sorted.overall_begin();
ARROW_ASSIGN_OR_RAISE(auto rankings,
MakeMutableUInt64Array(length, ctx->memory_pool()));
auto out_begin = rankings->GetMutableValues<uint64_t>(1);
uint64_t rank;

auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; };
auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; };

switch (tiebreaker) {
case RankOptions::Dense: {
T curr_value, prev_value{};
rank = 0;

if (null_placement == NullPlacement::AtStart && sorted.null_count() > 0) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
curr_value = value_selector(*it);
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
rank++;
}

out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtEnd) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
if (!is_duplicate(*it)) {
++rank;
}
out_begin[original_index(*it)] = rank;
}
break;
}

case RankOptions::First: {
rank = 0;
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); it++) {
// No duplicate marks expected for RankOptions::First
DCHECK(!is_duplicate(*it));
out_begin[*it] = ++rank;
}
break;
}

case RankOptions::Min: {
T curr_value, prev_value{};
rank = 0;

if (null_placement == NullPlacement::AtStart) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
curr_value = value_selector(*it);
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
if (!is_duplicate(*it)) {
rank = (it - sorted.overall_begin()) + 1;
}
out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtEnd) {
rank = sorted.non_null_count() + 1;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
out_begin[original_index(*it)] = rank;
}
break;
}

case RankOptions::Max: {
// The algorithm for Max is just like Min, but in reverse order.
T curr_value, prev_value{};
rank = length;

if (null_placement == NullPlacement::AtEnd) {
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_end - 1; it >= sorted.non_nulls_begin; it--) {
curr_value = value_selector(*it);

if (it == sorted.non_nulls_end - 1 || curr_value != prev_value) {
rank = (it - sorted.overall_begin()) + 1;
}
out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtStart) {
rank = sorted.null_count();
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
for (auto it = sorted.overall_end() - 1; it >= sorted.overall_begin(); --it) {
out_begin[original_index(*it)] = rank;
// If the current index isn't marked as duplicate, then it's the last
// tie in a row (since we iterate in reverse order), so update rank
// for the next row of ties.
if (!is_duplicate(*it)) {
rank = it - sorted.overall_begin();
}
}

break;
}
}
Expand Down Expand Up @@ -209,11 +190,14 @@ class Ranker<Array> : public RankerMixin<Array, Ranker<Array>> {
array_sorter(indices_begin_, indices_end_, array, 0,
ArraySortOptions(order_, null_placement_), ctx_));

auto value_selector = [&array](int64_t index) {
return GetView::LogicalValue(array.GetView(index));
};
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
tiebreaker_, value_selector));
if (NeedsDuplicates(tiebreaker_)) {
auto value_selector = [&array](int64_t index) {
return GetView::LogicalValue(array.GetView(index));
};
MarkDuplicates(sorted, value_selector);
}
ARROW_ASSIGN_OR_RAISE(*output_,
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));

return Status::OK();
}
Expand All @@ -238,13 +222,16 @@ class Ranker<ChunkedArray> : public RankerMixin<ChunkedArray, Ranker<ChunkedArra
SortChunkedArray(ctx_, indices_begin_, indices_end_, physical_type_,
physical_chunks_, order_, null_placement_));

const auto arrays = GetArrayPointers(physical_chunks_);
auto value_selector = [resolver = ChunkedArrayResolver(span(arrays))](int64_t index) {
return resolver.Resolve(index).Value<InType>();
};
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
tiebreaker_, value_selector));

if (NeedsDuplicates(tiebreaker_)) {
const auto arrays = GetArrayPointers(physical_chunks_);
auto value_selector = [resolver =
ChunkedArrayResolver(span(arrays))](int64_t index) {
return resolver.Resolve(index).Value<InType>();
};
MarkDuplicates(sorted, value_selector);
}
ARROW_ASSIGN_OR_RAISE(*output_,
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));
return Status::OK();
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/kernels/vector_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class ChunkedArraySorter : public TypeVisitor {
CompressedChunkLocation* nulls_middle,
CompressedChunkLocation* nulls_end,
CompressedChunkLocation* temp_indices, int64_t null_count) {
if (has_null_like_values<typename ArrayType::TypeClass>::value) {
if (has_null_like_values<typename ArrayType::TypeClass>()) {
PartitionNullsOnly<StablePartitioner>(nulls_begin, nulls_end, arrays,
null_count, null_placement_);
}
Expand Down Expand Up @@ -781,7 +781,7 @@ class TableSorter {
CompressedChunkLocation* nulls_middle,
CompressedChunkLocation* nulls_end,
CompressedChunkLocation* temp_indices, int64_t null_count) {
if constexpr (has_null_like_values<ArrowType>::value) {
if constexpr (has_null_like_values<ArrowType>()) {
// Merge rows with a null or a null-like in the first sort key
auto& comparator = comparator_;
const auto& first_sort_key = sort_keys_[0];
Expand Down
99 changes: 27 additions & 72 deletions cpp/src/arrow/compute/kernels/vector_sort_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
#include "arrow/type.h"
#include "arrow/type_traits.h"

namespace arrow {
namespace compute {
namespace internal {
namespace arrow::compute::internal {

// Visit all physical types for which sorting is implemented.
#define VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) \
Expand Down Expand Up @@ -71,49 +69,17 @@ struct StablePartitioner {
}
};

template <typename TypeClass, typename Enable = void>
struct NullTraits {
using has_null_like_values = std::false_type;
};

template <typename TypeClass>
struct NullTraits<TypeClass, enable_if_physical_floating_point<TypeClass>> {
using has_null_like_values = std::true_type;
};

template <typename TypeClass>
using has_null_like_values = typename NullTraits<TypeClass>::has_null_like_values;
constexpr bool has_null_like_values() {
return is_physical_floating(TypeClass::type_id);
}

// Compare two values, taking NaNs into account

template <typename Type, typename Enable = void>
struct ValueComparator;

template <typename Type>
struct ValueComparator<Type, enable_if_t<!has_null_like_values<Type>::value>> {
template <typename Value>
static int Compare(const Value& left, const Value& right, SortOrder order,
NullPlacement null_placement) {
int compared;
if (left == right) {
compared = 0;
} else if (left > right) {
compared = 1;
} else {
compared = -1;
}
if (order == SortOrder::Descending) {
compared = -compared;
}
return compared;
}
};

template <typename Type>
struct ValueComparator<Type, enable_if_t<has_null_like_values<Type>::value>> {
template <typename Value>
static int Compare(const Value& left, const Value& right, SortOrder order,
NullPlacement null_placement) {
template <typename Type, typename Value>
int CompareTypeValues(Value&& left, Value&& right, SortOrder order,
NullPlacement null_placement) {
if constexpr (has_null_like_values<Type>()) {
const bool is_nan_left = std::isnan(left);
const bool is_nan_right = std::isnan(right);
if (is_nan_left && is_nan_right) {
Expand All @@ -123,25 +89,19 @@ struct ValueComparator<Type, enable_if_t<has_null_like_values<Type>::value>> {
} else if (is_nan_right) {
return null_placement == NullPlacement::AtStart ? 1 : -1;
}
int compared;
if (left == right) {
compared = 0;
} else if (left > right) {
compared = 1;
} else {
compared = -1;
}
if (order == SortOrder::Descending) {
compared = -compared;
}
return compared;
}
};

template <typename Type, typename Value>
int CompareTypeValues(const Value& left, const Value& right, SortOrder order,
NullPlacement null_placement) {
return ValueComparator<Type>::Compare(left, right, order, null_placement);
int compared;
if (left == right) {
compared = 0;
} else if (left > right) {
compared = 1;
} else {
compared = -1;
}
if (order == SortOrder::Descending) {
compared = -compared;
}
return compared;
}

template <typename IndexType>
Expand Down Expand Up @@ -238,17 +198,15 @@ NullPartitionResult PartitionNullsOnly(uint64_t* indices_begin, uint64_t* indice
//
// `offset` is used when this is called on a chunk of a chunked array
template <typename ArrayType, typename Partitioner>
enable_if_t<!has_null_like_values<typename ArrayType::TypeClass>::value,
NullPartitionResult>
enable_if_t<!has_null_like_values<typename ArrayType::TypeClass>(), NullPartitionResult>
PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
const ArrayType& values, int64_t offset,
NullPlacement null_placement) {
return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement);
}

template <typename ArrayType, typename Partitioner>
enable_if_t<has_null_like_values<typename ArrayType::TypeClass>::value,
NullPartitionResult>
enable_if_t<has_null_like_values<typename ArrayType::TypeClass>(), NullPartitionResult>
PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
const ArrayType& values, int64_t offset,
NullPlacement null_placement) {
Expand Down Expand Up @@ -345,18 +303,17 @@ ChunkedNullPartitionResult PartitionNullsOnly(CompressedChunkLocation* locations
}

template <typename ArrayType, typename Partitioner>
enable_if_t<!has_null_like_values<typename ArrayType::TypeClass>::value,
NullPartitionResult>
enable_if_t<!has_null_like_values<typename ArrayType::TypeClass>(), NullPartitionResult>
PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
const ChunkedArrayResolver& resolver, NullPlacement null_placement) {
return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement);
}

template <typename ArrayType, typename Partitioner,
typename TypeClass = typename ArrayType::TypeClass>
enable_if_t<has_null_like_values<TypeClass>::value, NullPartitionResult>
PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
const ChunkedArrayResolver& resolver, NullPlacement null_placement) {
enable_if_t<has_null_like_values<TypeClass>(), NullPartitionResult> PartitionNullLikes(
uint64_t* indices_begin, uint64_t* indices_end, const ChunkedArrayResolver& resolver,
NullPlacement null_placement) {
Partitioner partitioner;
if (null_placement == NullPlacement::AtStart) {
auto null_likes_end = partitioner(indices_begin, indices_end, [&](uint64_t ind) {
Expand Down Expand Up @@ -853,6 +810,4 @@ inline Result<std::shared_ptr<ArrayData>> MakeMutableUInt64Array(
return ArrayData::Make(uint64(), length, {nullptr, std::move(data)}, /*null_count=*/0);
}

} // namespace internal
} // namespace compute
} // namespace arrow
} // namespace arrow::compute::internal
Loading

0 comments on commit 0219c1e

Please sign in to comment.