diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 6e2294371e7a6..f4da6552d20cf 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -756,6 +756,7 @@ if(ARROW_COMPUTE)
        compute/kernels/aggregate_tdigest.cc
        compute/kernels/aggregate_var_std.cc
        compute/kernels/hash_aggregate.cc
+       compute/kernels/pivot_internal.cc
        compute/kernels/scalar_arithmetic.cc
        compute/kernels/scalar_boolean.cc
        compute/kernels/scalar_compare.cc
diff --git a/cpp/src/arrow/acero/groupby_aggregate_node.cc b/cpp/src/arrow/acero/groupby_aggregate_node.cc
index 06b034ab2d459..2beef360b45d4 100644
--- a/cpp/src/arrow/acero/groupby_aggregate_node.cc
+++ b/cpp/src/arrow/acero/groupby_aggregate_node.cc
@@ -282,6 +282,11 @@ Status GroupByNode::Merge() {
       DCHECK(state0->agg_states[span_i]);
       batch_ctx.SetState(state0->agg_states[span_i].get());
 
+      // XXX this resizes each KernelState (state0->agg_states[span_i]) multiple times.
+      // An alternative would be a two-pass algorithm:
+      // 1. Compute all transpositions (one per local state) and the final number of
+      //    groups.
+      // 2. Process all agg kernels, resizing each KernelState only once.
       RETURN_NOT_OK(
           agg_kernels_[span_i]->resize(&batch_ctx, state0->grouper->num_groups()));
       RETURN_NOT_OK(agg_kernels_[span_i]->merge(
diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index 2e5210b073ee4..886d992640c5f 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -175,6 +175,75 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions {
   uint32_t min_count;
 };
 
+/// \brief Control Pivot kernel behavior
+///
+/// These options apply to the "hash_pivot" function (a scalar "pivot" is still TODO).
+///
+/// Constraints:
+/// - The corresponding `Aggregate::target` must have two FieldRef elements;
+///   the first one points to the pivot key column, the second points to the
+///   pivoted data column.
+/// - The pivot key column must be string-like; its values will be matched
+///   against `key_names` in order to dispatch the pivoted data into the
+///   output.
+///
+/// "hash_pivot" example
+/// --------------------
+///
+/// Assuming the following input with schema
+/// `{"group": int32, "key": utf8, "value": int16}`:
+/// ```
+///   group   |   key    |  value
+/// -----------------------------
+///     1     |  height  |   11
+///     1     |  width   |   12
+///     2     |  width   |   13
+///     3     |  height  |   14
+///     3     |  depth   |   15
+/// ```
+/// and the following settings:
+/// - a hash grouping key "group"
+/// - Aggregate(
+///     .function = "hash_pivot",
+///     .options = PivotOptions(.key_names = {"height", "width"}),
+///     .target = {"key", "value"},
+///     .name = {"props"})
+///
+/// then the output will have the schema
+/// `{"group": int32, "props": struct{"height": int16, "width": int16}}`
+/// and the following value:
+/// ```
+///   group   |     props
+///           |  height  | width
+/// -----------------------------
+///     1     |    11    |  12
+///     2     |   null   |  13
+///     3     |    14    | null
+/// ```
+class ARROW_EXPORT PivotOptions : public FunctionOptions {
+ public:
+  /// \brief Configure the behavior of pivot keys not in `key_names`
+  enum UnexpectedKeyBehavior {
+    /// Unexpected pivot keys are ignored silently
+    kIgnore,
+    /// Unexpected pivot keys return a KeyError
+    kRaise
+  };
+  // TODO should duplicate key behavior be configurable as well?
+
+  explicit PivotOptions(std::vector<std::string> key_names,
+                        UnexpectedKeyBehavior unexpected_key_behavior = kIgnore);
+  /// \brief Default constructor for serialization
+  PivotOptions();
+  static constexpr char const kTypeName[] = "PivotOptions";
+  static PivotOptions Defaults() { return PivotOptions{}; }
+
+  /// \brief The values expected in the pivot key column
+  std::vector<std::string> key_names;
+  /// \brief The behavior when pivot keys not in `key_names` are encountered
+  UnexpectedKeyBehavior unexpected_key_behavior = kIgnore;
+};
+
 /// \brief Control Index kernel behavior
 class ARROW_EXPORT IndexOptions : public FunctionOptions {
  public:
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 21b7bd9bf6632..73fbb157c0395 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -26,6 +26,7 @@
 
 #include "arrow/array/builder_nested.h"
 #include "arrow/array/builder_primitive.h"
+#include "arrow/array/concatenate.h"
 #include "arrow/buffer_builder.h"
 #include "arrow/compute/api_aggregate.h"
 #include "arrow/compute/api_vector.h"
@@ -33,6 +34,7 @@
 #include "arrow/compute/kernels/aggregate_internal.h"
 #include "arrow/compute/kernels/aggregate_var_std_internal.h"
 #include "arrow/compute/kernels/common_internal.h"
+#include "arrow/compute/kernels/pivot_internal.h"
 #include "arrow/compute/kernels/util_internal.h"
 #include "arrow/compute/row/grouper.h"
 #include "arrow/compute/row/row_encoder_internal.h"
@@ -40,6 +42,7 @@
 #include "arrow/stl_allocator.h"
 #include "arrow/type_traits.h"
 #include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
 #include "arrow/util/bitmap_ops.h"
 #include "arrow/util/bitmap_writer.h"
 #include "arrow/util/checked_cast.h"
@@ -47,6 +50,7 @@
 #include "arrow/util/int128_internal.h"
 #include "arrow/util/int_util_overflow.h"
 #include "arrow/util/ree_util.h"
+#include "arrow/util/span.h"
 #include "arrow/util/task_group.h"
 #include "arrow/util/tdigest.h"
#include "arrow/util/thread_pool.h" @@ -56,6 +60,7 @@ namespace arrow { using internal::checked_cast; using internal::FirstTimeBitmapWriter; +using util::span; namespace compute { namespace internal { @@ -3319,9 +3324,289 @@ struct GroupedListFactory { HashAggregateKernel kernel; InputType argument_type; }; -} // namespace -namespace { +// ---------------------------------------------------------------------- +// Pivot implementation + +struct GroupedPivotAccumulator { + Status Init(ExecContext* ctx, std::shared_ptr value_type, + const PivotOptions* options) { + ctx_ = ctx; + value_type_ = std::move(value_type); + num_keys_ = static_cast(options->key_names.size()); + num_groups_ = 0; + columns_.resize(num_keys_); + return Status::OK(); + } + + Status Consume(span groups, span keys, + const ArraySpan& values) { + using TakeIndexType = UInt32Type; + using TakeIndex = typename TypeTraits::CType; + + DCHECK_EQ(groups.size(), keys.size()); + DCHECK_EQ(groups.size(), static_cast(values.length)); + // TODO allocate single buffers and slice them for individual columns + std::vector> take_indices(num_keys_); + std::vector> take_bitmaps(num_keys_); + for (int i = 0; i < num_keys_; ++i) { + ARROW_ASSIGN_OR_RAISE( + take_indices[i], + AllocateBuffer(num_groups_ * sizeof(TakeIndex), ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE(take_bitmaps[i], + AllocateEmptyBitmap(num_groups_, ctx_->memory_pool())); + } + // Populate the indices to take from for each grouped column + for (int64_t i = 0; i < values.length; ++i) { + // TODO cache the mutable_data() + const PivotKeyIndex key = keys[i]; + const uint32_t group = groups[i]; + if (key != kNullPivotKey) { + DCHECK_LT(static_cast(key), num_keys_); + // XXX check for already existing entry? 
+ // For row #group in column #key, we are going to take the value at index #i + bit_util::SetBit(take_bitmaps[key]->mutable_data(), group); + take_indices[key]->mutable_data_as()[group] = i; + } + } + // Compute the grouped columns for this batch + auto values_data = values.ToArrayData(); + ArrayVector new_columns(num_keys_); + for (int i = 0; i < num_keys_; ++i) { + auto indices_data = + ArrayData::Make(TypeTraits::type_singleton(), num_groups_, + {std::move(take_bitmaps[i]), std::move(take_indices[i])}); + ARROW_ASSIGN_OR_RAISE(Datum grouped_column, Take(values_data, indices_data, + TakeOptions::Defaults(), ctx_)); + new_columns[i] = grouped_column.make_array(); + } + return MergeColumns(std::move(new_columns)); + } + + Status Consume(span groups, const PivotKeyIndex key, + const ArraySpan& values) { + using TakeIndexType = UInt32Type; + using TakeIndex = typename TypeTraits::CType; + + if (key == kNullPivotKey) { + // Nothing to update + return Status::OK(); + } + DCHECK_LT(static_cast(key), num_keys_); + + // Only the column #key needs to be updated + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + AllocateBuffer(num_groups_ * sizeof(TakeIndex), ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE(auto take_bitmap, + AllocateEmptyBitmap(num_groups_, ctx_->memory_pool())); + + DCHECK_EQ(groups.size(), static_cast(values.length)); + for (int64_t i = 0; i < values.length; ++i) { + const uint32_t group = groups[i]; + // XXX check for already existing entry? 
+ bit_util::SetBit(take_bitmap->mutable_data(), group); + take_indices->mutable_data_as()[group] = i; + } + auto values_data = values.ToArrayData(); + auto indices_data = + ArrayData::Make(TypeTraits::type_singleton(), num_groups_, + {std::move(take_bitmap), std::move(take_indices)}); + ARROW_ASSIGN_OR_RAISE(Datum grouped_column, + Take(values_data, indices_data, TakeOptions::Defaults(), ctx_)); + return MergeColumn(&columns_[key], grouped_column.make_array()); + } + + Status Resize(int64_t new_num_groups) { return ResizeColumns(new_num_groups); } + + Status Merge(GroupedPivotAccumulator&& other, const ArrayData& group_id_mapping) { + // TODO need "scatter" function (inverse of "take") + auto scatter_indices = group_id_mapping.Copy(); + auto scatter_column = + [&](const std::shared_ptr& column) -> Result> { + ScatterOptions options(/*max_index=*/num_groups_ + 1); + return CallFunction("scatter", {column, scatter_indices}, &options, ctx_); + }; + return MergeColumns(std::move(other.columns_), std::move(scatter_column)); + } + + Result Finalize() { + // Ensure that columns are allocated even if num_groups_ == 0 + RETURN_NOT_OK(ResizeColumns(num_groups_)); + return std::move(columns_); + } + + protected: + Status ResizeColumns(int64_t new_num_groups) { + if (new_num_groups == num_groups_ && num_groups_ != 0) { + return Status::OK(); + } + ARROW_ASSIGN_OR_RAISE( + auto array_suffix, + MakeArrayOfNull(value_type_, new_num_groups - num_groups_, ctx_->memory_pool())); + for (auto& column : columns_) { + if (num_groups_ != 0) { + DCHECK_NE(column, nullptr); + ARROW_ASSIGN_OR_RAISE( + column, Concatenate({std::move(column), array_suffix}, ctx_->memory_pool())); + } else { + column = array_suffix; + } + DCHECK_EQ(column->length(), new_num_groups); + } + num_groups_ = new_num_groups; + return Status::OK(); + } + + using ColumnTransform = + std::function>(const std::shared_ptr&)>; + + Status MergeColumns(ArrayVector&& other_columns, + const ColumnTransform& transform = {}) { + 
DCHECK_EQ(columns_.size(), other_columns.size()); + for (int i = 0; i < num_keys_; ++i) { + if (other_columns[i]) { + RETURN_NOT_OK(MergeColumn(&columns_[i], std::move(other_columns[i]), transform)); + } + } + return Status::OK(); + } + + Status MergeColumn(std::shared_ptr* column, std::shared_ptr other_column, + const ColumnTransform& transform = {}) { + if (transform) { + ARROW_ASSIGN_OR_RAISE(other_column, transform(other_column)); + } + DCHECK_EQ(num_groups_, other_column->length()); + if (*column) { + int64_t expected_non_nulls = (num_groups_ - (*column)->null_count()) + + (num_groups_ - other_column->null_count()); + ARROW_ASSIGN_OR_RAISE(auto coalesced, + CallFunction("coalesce", {*column, other_column}, ctx_)); + if (expected_non_nulls != num_groups_ - coalesced.null_count()) { + return Status::Invalid( + "Encountered more than one value for the same grouped pivot key"); + } + *column = coalesced.make_array(); + } else { + *column = other_column; + } + return Status::OK(); + } + + ExecContext* ctx_; + std::shared_ptr value_type_; + int num_keys_; + int64_t num_groups_; + ArrayVector columns_; +}; + +struct GroupedPivotImpl : public GroupedAggregator { + Status Resize(int64_t new_num_groups) override { + num_groups_ = new_num_groups; + return accumulator_.Resize(new_num_groups); + } + + Status Merge(GroupedAggregator&& raw_other, + const ArrayData& group_id_mapping) override { + auto other = checked_cast(&raw_other); + return accumulator_.Merge(std::move(other->accumulator_), group_id_mapping); + } + + Status Consume(const ExecSpan& batch) override { + DCHECK_EQ(batch.values.size(), 3); + auto groups = batch[2].array.GetSpan(1, batch.length); + if (!batch[1].is_array()) { + return Status::NotImplemented("Consuming scalar pivot value"); + } + if (batch[0].is_array()) { + ARROW_ASSIGN_OR_RAISE(span keys, + key_mapper_->MapKeys(batch[0].array)); + return accumulator_.Consume(groups, keys, batch[1].array); + } else { + ARROW_ASSIGN_OR_RAISE(PivotKeyIndex key, 
key_mapper_->MapKey(*batch[0].scalar)); + return accumulator_.Consume(groups, key, batch[1].array); + } + } + + Result Finalize() override { + ARROW_ASSIGN_OR_RAISE(auto columns, accumulator_.Finalize()); + DCHECK_EQ(columns.size(), static_cast(out_struct_type_->num_fields())); + return std::make_shared(out_type_, num_groups_, std::move(columns), + /*null_bitmap=*/nullptr, + /*null_count=*/0); + } + + std::shared_ptr out_type() const override { return out_type_; } + + std::shared_ptr key_type_; + std::shared_ptr out_type_; + const StructType* out_struct_type_; + const PivotOptions* options_; + std::unique_ptr key_mapper_; + GroupedPivotAccumulator accumulator_; + int64_t num_groups_ = 0; +}; + +// TODO the template is unused +template +struct TypedGroupedPivotImpl : public GroupedPivotImpl { + Status Init(ExecContext* ctx, const KernelInitArgs& args) override { + DCHECK_EQ(args.inputs.size(), 3); + key_type_ = args.inputs[0].GetSharedPtr(); + options_ = checked_cast(args.options); + DCHECK_NE(options_, nullptr); + auto value_type = args.inputs[1].GetSharedPtr(); + FieldVector fields; + fields.reserve(options_->key_names.size()); + for (const auto& key_name : options_->key_names) { + fields.push_back(field(key_name, value_type)); + } + out_type_ = struct_(std::move(fields)); + out_struct_type_ = checked_cast(out_type_.get()); + ARROW_ASSIGN_OR_RAISE(key_mapper_, PivotKeyMapper::Make(*key_type_, options_)); + RETURN_NOT_OK(accumulator_.Init(ctx, value_type, options_)); + return Status::OK(); + } +}; + +// TODO simplify this away? 
+template <typename KeyType>
+Result<std::unique_ptr<KernelState>> GroupedPivotInit(KernelContext* ctx,
+                                                      const KernelInitArgs& args) {
+  ARROW_ASSIGN_OR_RAISE(auto impl,
+                        HashAggregateInit<TypedGroupedPivotImpl<KeyType>>(ctx, args));
+  return impl;
+}
+
+struct GroupedPivotFactory {  // builds the "hash_pivot" kernel for one pivot key type
+  template <typename KeyType>
+  enable_if_base_binary<KeyType, Status> Visit(const KeyType& type) {
+    // TODO replace Any() with a more selective matcher for the value type
+    auto sig =
+        KernelSignature::Make({type.id(), InputType::Any(), InputType(Type::UINT32)},
+                              OutputType(ResolveGroupOutputType));
+    kernel = MakeKernel(std::move(sig), GroupedPivotInit<KeyType>);
+    return Status::OK();
+  }
+
+  Status Visit(const DataType& type) {
+    return Status::TypeError("Unsupported pivot key type: ", type);
+  }
+
+  static Result<HashAggregateKernel> Make(
+      const std::shared_ptr<DataType>& pivot_key_type) {
+    GroupedPivotFactory factory;
+    RETURN_NOT_OK(VisitTypeInline(*pivot_key_type, &factory));
+    return std::move(factory.kernel);
+  }
+
+  HashAggregateKernel kernel;
+};
+
+// ----------------------------------------------------------------------
+// Docstrings
+
 const FunctionDoc hash_count_doc{
     "Count the number of null / non-null values in each group",
     ("By default, only non-null values are counted.\n"
@@ -3456,6 +3741,18 @@ const FunctionDoc hash_one_doc{"Get one value from each group",
 const FunctionDoc hash_list_doc{"List all values in each group",
                                 ("Null values are also returned."),
                                 {"array", "group_id_array"}};
+
+const FunctionDoc hash_pivot_doc{
+    "Pivot values according to a pivot key column",
+    ("Output is a struct array with as many fields as `PivotOptions.key_names`.\n"
+     "All output struct fields have the same type as `pivot_values`.\n"
+     "Each pivot key decides in which output field the corresponding pivot value\n"
+     "is emitted. If a pivot key doesn't appear in a given group, null is emitted.\n"
+     "If a pivot key appears twice in a given group, Invalid is raised.\n"
+     "Behavior of unexpected pivot keys is controlled by PivotOptions."),
+    {"pivot_keys", "pivot_values", "group_id_array"},
+    "PivotOptions"};
+
 }  // namespace
 
 void RegisterHashAggregateBasic(FunctionRegistry* registry) {
@@ -3705,6 +4002,14 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
                                 GroupedListFactory::Make, func.get()));
     DCHECK_OK(registry->AddFunction(std::move(func)));
   }
+
+  {
+    auto func = std::make_shared<HashAggregateFunction>("hash_pivot", Arity::Ternary(),
+                                                        hash_pivot_doc);
+    DCHECK_OK(
+        AddHashAggKernels(BaseBinaryTypes(), GroupedPivotFactory::Make, func.get()));
+    DCHECK_OK(registry->AddFunction(std::move(func)));
+  }
 }
 
 }  // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/pivot_internal.cc b/cpp/src/arrow/compute/kernels/pivot_internal.cc
new file mode 100644
index 0000000000000..e9c86f6119e5a
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/pivot_internal.cc
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/pivot_internal.h"
+
+#include <unordered_map>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/scalar.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/visit_type_inline.h"
+
+namespace arrow::compute::internal {
+
+using ::arrow::util::span;
+
+// Pivot key mapper logic shared by all key types: builds the key name lookup
+// table from PivotOptions and resolves individual key names.
+struct BasePivotKeyMapper : public PivotKeyMapper {
+  // Build the lookup table. Fails on too many or duplicate key names.
+  Status Init(const PivotOptions* options) override {
+    if (options->key_names.size() > static_cast<size_t>(kMaxPivotKey) + 1) {
+      return Status::NotImplemented("Pivoting to more than ",
+                                    static_cast<int>(kMaxPivotKey) + 1,
+                                    " columns: got ", options->key_names.size());
+    }
+    key_name_map_.reserve(options->key_names.size());
+    PivotKeyIndex index = 0;
+    for (const auto& key_name : options->key_names) {
+      bool inserted =
+          key_name_map_.try_emplace(std::string_view(key_name), index++).second;
+      if (!inserted) {
+        return Status::KeyError("Duplicate key name '", key_name, "' in PivotOptions");
+      }
+    }
+    unexpected_key_behavior_ = options->unexpected_key_behavior;
+    return Status::OK();
+  }
+
+ protected:
+  // Handle a key name not listed in `key_names`, as configured by PivotOptions
+  Result<PivotKeyIndex> KeyNotFound(std::string_view key_name) {
+    if (unexpected_key_behavior_ == PivotOptions::kIgnore) {
+      return kNullPivotKey;
+    }
+    return Status::KeyError("Unexpected pivot key: ", key_name);
+  }
+
+  // Return the index of `key_name` in `key_names`, or defer to KeyNotFound()
+  Result<PivotKeyIndex> LookupKey(std::string_view key_name) {
+    const auto it = this->key_name_map_.find(key_name);
+    if (ARROW_PREDICT_FALSE(it == this->key_name_map_.end())) {
+      return KeyNotFound(key_name);
+    } else {
+      return it->second;
+    }
+  }
+
+  // Error returned when a pivot key is null
+  static Status NullKeyName() { return Status::KeyError("key name cannot be null"); }
+
+  static constexpr int kBatchLength = 512;
+  // The strings backing the string_views should be kept alive by PivotOptions.
+  std::unordered_map<std::string_view, PivotKeyIndex> key_name_map_;
+  PivotOptions::UnexpectedKeyBehavior unexpected_key_behavior_;
+  TypedBufferBuilder<PivotKeyIndex> key_indices_buffer_;
+};
+
+// Pivot key mapper for the binary-like key type `KeyType`
+template <typename KeyType>
+struct TypedPivotKeyMapper : public BasePivotKeyMapper {
+  // The returned span points into `key_indices_buffer_`, which is written
+  // again from its start by the next call to MapKeys().
+  Result<span<const PivotKeyIndex>> MapKeys(const ArraySpan& array) override {
+    RETURN_NOT_OK(this->key_indices_buffer_.Reserve(array.length));
+    PivotKeyIndex* key_indices = this->key_indices_buffer_.mutable_data();
+    int64_t i = 0;
+    RETURN_NOT_OK(VisitArrayValuesInline<KeyType>(
+        array,
+        [&](std::string_view key_name) {
+          ARROW_ASSIGN_OR_RAISE(key_indices[i], LookupKey(key_name));
+          ++i;
+          return Status::OK();
+        },
+        [&]() { return NullKeyName(); }));
+    return span<const PivotKeyIndex>(key_indices, array.length);
+  }
+
+  Result<PivotKeyIndex> MapKey(const Scalar& scalar) override {
+    if (!scalar.is_valid) {
+      return NullKeyName();
+    }
+    const auto& binary_scalar = checked_cast<const BaseBinaryScalar&>(scalar);
+    return LookupKey(binary_scalar.view());
+  }
+};
+
+// Type visitor creating and initializing the mapper matching a pivot key type
+struct PivotKeyMapperFactory {
+  template <typename T>
+  Status Visit(const T& key_type) {
+    if constexpr (is_base_binary_like(T::type_id)) {
+      instance = std::make_unique<TypedPivotKeyMapper<T>>();
+      return instance->Init(options);
+    }
+    return Status::NotImplemented("Pivot key type: ", key_type);
+  }
+
+  const PivotOptions* options;
+  std::unique_ptr<PivotKeyMapper> instance{};
+};
+
+// Fails with NotImplemented if `key_type` is not a binary-like type
+Result<std::unique_ptr<PivotKeyMapper>> PivotKeyMapper::Make(
+    const DataType& key_type, const PivotOptions* options) {
+  PivotKeyMapperFactory factory{options};
+  RETURN_NOT_OK(VisitTypeInline(key_type, &factory));
+  return std::move(factory).instance;
+}
+
+/*
+TODO
+would probably like to write:
+
+Result<std::unique_ptr<PivotKeyMapper>> PivotKeyMapper::Make(const DataType& key_type,
+                                                             const PivotOptions* options)
+{ std::unique_ptr<PivotKeyMapper> instance; RETURN_NOT_OK(VisitTypeInline(key_type,
+[&](auto key_type) { using T = std::decay_t<decltype(key_type)>; if constexpr
+(is_base_binary_like(T::type_id)) { instance = std::make_unique<TypedPivotKeyMapper<T>>();
+    return instance->Init(options);
+  }
+  return Status::NotImplemented("Pivot key type: ", key_type);
+  }));
+  return instance;
+}
+
+or even:
+
+Result<std::unique_ptr<PivotKeyMapper>> PivotKeyMapper::Make(const DataType& key_type,
+                                                             const PivotOptions* options)
+{ return VisitTypeInline(key_type, [&](auto key_type) ->
+Result<std::unique_ptr<PivotKeyMapper>> { using T = std::decay_t<decltype(key_type)>; if
+constexpr (is_base_binary_like(T::type_id)) { auto instance =
+std::make_unique<TypedPivotKeyMapper<T>>(); RETURN_NOT_OK(instance->Init(options)); return
+instance;
+  }
+  return Status::NotImplemented("Pivot key type: ", key_type);
+  });
+}
+*/
+
+}  // namespace arrow::compute::internal
diff --git a/cpp/src/arrow/compute/kernels/pivot_internal.h b/cpp/src/arrow/compute/kernels/pivot_internal.h
new file mode 100644
index 0000000000000..2921c976d1f91
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/pivot_internal.h
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/span.h"
+
+namespace arrow::compute::internal {
+
+// Index of a pivot key in `PivotOptions::key_names`
+using PivotKeyIndex = uint8_t;
+
+// Marker returned for pivot keys that must be ignored
+constexpr PivotKeyIndex kNullPivotKey = std::numeric_limits<PivotKeyIndex>::max();
+// Largest usable pivot key index
+constexpr PivotKeyIndex kMaxPivotKey = kNullPivotKey - 1;
+
+// Maps pivot key values to their index in `PivotOptions::key_names`
+struct PivotKeyMapper {
+  virtual ~PivotKeyMapper() = default;
+
+  virtual Status Init(const PivotOptions* options) = 0;
+  // Map each key of the array to its index
+  virtual Result<::arrow::util::span<const PivotKeyIndex>> MapKeys(const ArraySpan&) = 0;
+  // Map a single scalar key to its index
+  virtual Result<PivotKeyIndex> MapKey(const Scalar&) = 0;
+
+  // Create and initialize a mapper for the given pivot key type
+  static Result<std::unique_ptr<PivotKeyMapper>> Make(const DataType& key_type,
+                                                      const PivotOptions* options);
+};
+
+}  // namespace arrow::compute::internal
diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h
index 89f32ceb0f906..016d97a0dbc2b 100644
--- a/cpp/src/arrow/compute/type_fwd.h
+++ b/cpp/src/arrow/compute/type_fwd.h
@@ -40,6 +40,7 @@ class CastOptions;
 
 struct ExecBatch;
 class ExecContext;
+struct ExecValue;
 class KernelContext;
 
 struct Kernel;