Skip to content

Commit

Permalink
GH-45269: [C++][Compute] Add pivot function
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Jan 16, 2025
1 parent e434536 commit b45d585
Show file tree
Hide file tree
Showing 9 changed files with 673 additions and 10 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ if(ARROW_COMPUTE)
compute/kernels/aggregate_tdigest.cc
compute/kernels/aggregate_var_std.cc
compute/kernels/hash_aggregate.cc
compute/kernels/pivot_internal.cc
compute/kernels/scalar_arithmetic.cc
compute/kernels/scalar_boolean.cc
compute/kernels/scalar_compare.cc
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/acero/groupby_aggregate_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,13 @@ Result<AggregateNodeArgs<HashAggregateKernel>> GroupByNode::MakeAggregateNodeArg

// Find input field indices for aggregates
std::vector<std::vector<int>> agg_src_fieldsets(aggs.size());
// ARROW_LOG(INFO) << "input schema: " << input_schema->ToString();
for (size_t i = 0; i < aggs.size(); ++i) {
const auto& target_fieldset = aggs[i].target;
// ARROW_LOG(INFO) << "target #" << i << " has " << target_fieldset.size() << "
// targets";
for (const auto& target : target_fieldset) {
// ARROW_LOG(INFO) << " ... " << target.ToString();
ARROW_ASSIGN_OR_RAISE(auto match, target.FindOne(*input_schema));
agg_src_fieldsets[i].push_back(match[0]);
}
Expand All @@ -108,6 +112,8 @@ Result<AggregateNodeArgs<HashAggregateKernel>> GroupByNode::MakeAggregateNodeArg
std::vector<std::vector<TypeHolder>> agg_src_types(aggs.size());
for (size_t i = 0; i < aggs.size(); ++i) {
for (const auto& agg_src_field_id : agg_src_fieldsets[i]) {
// ARROW_LOG(INFO) << "target #" << i << " field = " <<
// input_schema->field(agg_src_field_id)->ToString();
agg_src_types[i].push_back(input_schema->field(agg_src_field_id)->type().get());
}
}
Expand Down Expand Up @@ -282,6 +288,11 @@ Status GroupByNode::Merge() {
DCHECK(state0->agg_states[span_i]);
batch_ctx.SetState(state0->agg_states[span_i].get());

// XXX this resizes each KernelState (state0->agg_states[span_i]) multiple times.
// An alternative would be a two-pass algorithm:
// 1. Compute all transpositions (one per local state) and the final number of
// groups.
// 2. Process all agg kernels, resizing each KernelState only once.
RETURN_NOT_OK(
agg_kernels_[span_i]->resize(&batch_ctx, state0->grouper->num_groups()));
RETURN_NOT_OK(agg_kernels_[span_i]->merge(
Expand Down
63 changes: 56 additions & 7 deletions cpp/src/arrow/acero/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,9 @@
#include "arrow/array/concatenate.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/compute/row/grouper.h"
#include "arrow/compute/row/grouper_internal.h"
Expand All @@ -51,9 +47,7 @@
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/bitmap_reader.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/string.h"
Expand All @@ -65,7 +59,6 @@ using testing::HasSubstr;

namespace arrow {

using internal::BitmapReader;
using internal::checked_cast;
using internal::checked_pointer_cast;
using internal::ToChars;
Expand All @@ -76,6 +69,7 @@ using compute::default_exec_context;
using compute::ExecSpan;
using compute::FunctionOptions;
using compute::Grouper;
using compute::PivotOptions;
using compute::RowSegmenter;
using compute::ScalarAggregateOptions;
using compute::Segment;
Expand Down Expand Up @@ -1489,6 +1483,7 @@ class GroupBy : public ::testing::TestWithParam<GroupByFunction> {
return acero::GroupByTest(GetParam(), arguments, keys, aggregates, use_threads);
}

// TODO why not rename this to GroupByTest?
Result<Datum> AltGroupBy(const std::vector<Datum>& arguments,
const std::vector<Datum>& keys,
const std::vector<Datum>& segment_keys,
Expand Down Expand Up @@ -5269,6 +5264,60 @@ TEST_P(GroupBy, OnlyKeys) {
}
}

// Checks the "hash_pivot" aggregation on float32 values: rows are grouped by
// "group_key", and each group's "value" entries are dispatched into a struct
// with one field per expected pivot key ("height", "width"). A group with no
// row for a given pivot key gets a null in the corresponding field.
//
// The test runs once per base binary key type, serially and with threads.
//
// TODO add coverage for:
// - unused key_names
// - unexpected key_names
// - duplicate values
// - duplicate keys
// - nulls in keys
// - nulls in values
TEST_P(GroupBy, PivotFloatValues) {
  auto value_type = float32();
  for (bool use_threads : {false, true}) {
    for (const auto& key_type : BaseBinaryTypes()) {
      ARROW_SCOPED_TRACE("key_type = ", *key_type);
      ARROW_SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

      // The pivot key column must use the key type under test; a hard-coded
      // utf8() here would make every iteration of this loop identical.
      auto table =
          TableFromJSON(schema({field("group_key", int64()), field("key", key_type),
                                field("value", value_type)}),
                        {R"([
[1, "width", 10.5],
[2, "width", 11.5]
])",
                         R"([
[2, "height", 12.5],
[3, "width", 13.5],
[1, "height", 14.5]
])"});

      // Construct the options in place instead of copying a temporary.
      auto options = std::make_shared<PivotOptions>(
          /*key_names=*/std::vector<std::string>{"height", "width"});
      // The two targets are (pivot key column, pivoted value column).
      // NOTE(review): "agg_0"/"agg_1" are presumably the names AltGroupBy gives
      // to the positional arguments below; confirm against the helper.
      Aggregate agg{"hash_pivot", options,
                    /*target=*/std::vector<FieldRef>{"agg_0", "agg_1"}, /*name=*/"out"};
      ASSERT_OK_AND_ASSIGN(
          Datum aggregated_and_grouped,
          AltGroupBy({table->GetColumnByName("key"), table->GetColumnByName("value")},
                     {table->GetColumnByName("group_key")},
                     /*segment_keys=*/{}, {agg}, use_threads));
      ValidateOutput(aggregated_and_grouped);

      // Group 3 has no "height" row, hence the null.
      AssertDatumsEqual(
          ArrayFromJSON(struct_({
                            field("key_0", int64()),
                            field("out", struct_({field("height", value_type),
                                                  field("width", value_type)})),
                        }),
                        R"([
[1, {"height": 14.5, "width": 10.5} ],
[2, {"height": 12.5, "width": 11.5} ],
[3, {"height": null, "width": 13.5} ]
])"),
          aggregated_and_grouped,
          /*verbose=*/true);
    }
  }
}

INSTANTIATE_TEST_SUITE_P(GroupBy, GroupBy, ::testing::Values(RunGroupByImpl));

class SegmentedScalarGroupBy : public GroupBy {};
Expand Down
30 changes: 29 additions & 1 deletion cpp/src/arrow/compute/api_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
#include "arrow/util/logging.h"

namespace arrow {

namespace internal {

template <>
struct EnumTraits<compute::CountOptions::CountMode>
: BasicEnumTraits<compute::CountOptions::CountMode, compute::CountOptions::ONLY_VALID,
Expand Down Expand Up @@ -67,6 +67,23 @@ struct EnumTraits<compute::QuantileOptions::Interpolation>
return "<INVALID>";
}
};

// Enum traits for PivotOptions::UnexpectedKeyBehavior: supplies a printable
// name for the enum type itself and for each of its values.
template <>
struct EnumTraits<compute::PivotOptions::UnexpectedKeyBehavior>
    : BasicEnumTraits<compute::PivotOptions::UnexpectedKeyBehavior,
                      compute::PivotOptions::kIgnore, compute::PivotOptions::kRaise> {
  // Printable name of the enum type.
  static std::string name() { return "PivotOptions::UnexpectedKeyBehavior"; }

  // Printable name of `value`; anything that is not a declared enumerator
  // yields "<INVALID>".
  static std::string value_name(compute::PivotOptions::UnexpectedKeyBehavior value) {
    if (value == compute::PivotOptions::kIgnore) {
      return "kIgnore";
    }
    if (value == compute::PivotOptions::kRaise) {
      return "kRaise";
    }
    return "<INVALID>";
  }
};

} // namespace internal

namespace compute {
Expand Down Expand Up @@ -101,6 +118,9 @@ static auto kTDigestOptionsType = GetFunctionOptionsType<TDigestOptions>(
DataMember("buffer_size", &TDigestOptions::buffer_size),
DataMember("skip_nulls", &TDigestOptions::skip_nulls),
DataMember("min_count", &TDigestOptions::min_count));
// Options type object for PivotOptions, built from its two data members
// ("key_names" and "unexpected_key_behavior"). It is passed to the
// FunctionOptions base by the PivotOptions constructors and added to the
// function registry in RegisterAggregateOptions().
static auto kPivotOptionsType = GetFunctionOptionsType<PivotOptions>(
DataMember("key_names", &PivotOptions::key_names),
DataMember("unexpected_key_behavior", &PivotOptions::unexpected_key_behavior));
// Options type object for IndexOptions, built from its single "value" member.
static auto kIndexOptionsType =
GetFunctionOptionsType<IndexOptions>(DataMember("value", &IndexOptions::value));
} // namespace
Expand Down Expand Up @@ -164,6 +184,13 @@ TDigestOptions::TDigestOptions(std::vector<double> q, uint32_t delta,
min_count{min_count} {}
constexpr char TDigestOptions::kTypeName[];

PivotOptions::PivotOptions(std::vector<std::string> key_names,
UnexpectedKeyBehavior unexpected_key_behavior)
: FunctionOptions(internal::kPivotOptionsType),
key_names(std::move(key_names)),
unexpected_key_behavior(unexpected_key_behavior) {}
PivotOptions::PivotOptions() : FunctionOptions(internal::kPivotOptionsType) {}

// Construct IndexOptions from the scalar `value`, which is moved into the
// `value` member.
IndexOptions::IndexOptions(std::shared_ptr<Scalar> value)
: FunctionOptions(internal::kIndexOptionsType), value{std::move(value)} {}
// Default constructor: delegates to the constructor above with a NullScalar.
IndexOptions::IndexOptions() : IndexOptions(std::make_shared<NullScalar>()) {}
Expand All @@ -177,6 +204,7 @@ void RegisterAggregateOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kVarianceOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kQuantileOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kTDigestOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPivotOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kIndexOptionsType));
}
} // namespace internal
Expand Down
69 changes: 69 additions & 0 deletions cpp/src/arrow/compute/api_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,75 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions {
uint32_t min_count;
};

/// \brief Control Pivot kernel behavior
///
/// These options apply to the "hash_pivot" function. They are also meant to
/// apply to a "pivot" function (TODO: confirm once that function is available).
///
/// Constraints:
/// - The corresponding `Aggregate::target` must have two FieldRef elements;
/// the first one points to the pivot key column, the second points to the
/// pivoted data column.
/// - The pivot key column must be string-like; its values will be matched
/// against `key_names` in order to dispatch the pivoted data into the
/// output.
///
/// "hash_pivot" example
/// --------------------
///
/// Assuming the following input with schema
/// `{"group": int32, "key": utf8, "value": int16}`:
/// ```
/// group | key | value
/// -----------------------------
/// 1 | height | 11
/// 1 | width | 12
/// 2 | width | 13
/// 3 | height | 14
/// 3 | depth | 15
/// ```
/// and the following settings:
/// - a hash grouping key "group"
/// - Aggregate(
/// .function = "hash_pivot",
/// .options = PivotOptions(.key_names = {"height", "width"}),
/// .target = {"key", "value"},
/// .name = "props")
///
/// then the output will have the schema
/// `{"group": int32, "props": struct{"height": int16, "width": int16}}`
/// and the following value:
/// ```
/// group | props
/// | height | width
/// -----------------------------
/// 1 | 11 | 12
/// 2 | null | 13
/// 3 | 14 | null
/// ```
///
/// In this example, a group with no row for a given key name gets a null in
/// the corresponding output field. The "depth" row of group 3 is dropped
/// because "depth" is not in `key_names` and the default
/// `unexpected_key_behavior` is `kIgnore`.
class ARROW_EXPORT PivotOptions : public FunctionOptions {
public:
/// Configure the behavior of pivot keys not in `key_names`
enum UnexpectedKeyBehavior {
/// Unexpected pivot keys are ignored silently
kIgnore,
/// Unexpected pivot keys return a KeyError
kRaise
};
// TODO should duplicate key behavior be configurable as well?

/// \brief Construct pivot options
///
/// \param key_names the values expected in the pivot key column
/// \param unexpected_key_behavior what to do with pivot keys that are not in
/// `key_names` (ignored by default)
explicit PivotOptions(std::vector<std::string> key_names,
UnexpectedKeyBehavior unexpected_key_behavior = kIgnore);
/// Default constructor for serialization
PivotOptions();
static constexpr char const kTypeName[] = "PivotOptions";
/// Return default-constructed options (empty `key_names`, `kIgnore`)
static PivotOptions Defaults() { return PivotOptions{}; }

/// The values expected in the pivot key column
std::vector<std::string> key_names;
/// The behavior when pivot keys not in `key_names` are encountered
UnexpectedKeyBehavior unexpected_key_behavior = kIgnore;
};

/// \brief Control Index kernel behavior
class ARROW_EXPORT IndexOptions : public FunctionOptions {
public:
Expand Down
Loading

0 comments on commit b45d585

Please sign in to comment.