From 7b2f703c595cc00fa7b71b606dc2c937f772d950 Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Tue, 19 Dec 2023 10:49:40 +0800 Subject: [PATCH] avoid hash mean overflow --- cpp/src/arrow/acero/hash_aggregate_test.cc | 36 +++++++++++++++++++ .../arrow/compute/kernels/hash_aggregate.cc | 24 +++++++++---- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc index a4874f3581040..2626fd50379dd 100644 --- a/cpp/src/arrow/acero/hash_aggregate_test.cc +++ b/cpp/src/arrow/acero/hash_aggregate_test.cc @@ -1694,6 +1694,42 @@ TEST_P(GroupBy, SumMeanProductScalar) { } } +TEST_P(GroupBy, MeanOverflow) { + BatchesWithSchema input; + // would overflow if intermediate sum is integer + input.batches = { + ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, + + "[[9223372036854775805, 1], [9223372036854775805, 1], " + "[9223372036854775805, 2], [9223372036854775805, 3]]"), + ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, + "[[null, 1], [null, 1], [null, 2], [null, 3]]"), + ExecBatchFromJSON({int64(), int64()}, + "[[9223372036854775805, 1], [9223372036854775805, 2], " + "[9223372036854775805, 3]]"), + }; + input.schema = schema({field("argument", int64()), field("key", int64())}); + for (bool use_threads : {true, false}) { + SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); + ASSERT_OK_AND_ASSIGN(Datum actual, + RunGroupBy(input, {"key"}, + { + {"hash_mean", nullptr, "argument", "hash_mean"}, + }, + use_threads)); + Datum expected = ArrayFromJSON(struct_({ + field("key", int64()), + field("hash_mean", float64()), + }), + R"([ + [1, 9223372036854775805], + [2, 9223372036854775805], + [3, 9223372036854775805] + ])"); + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } +} + TEST_P(GroupBy, VarianceAndStddev) { auto batch = RecordBatchFromJSON( schema({field("argument", int32()), field("key", int64())}), R"([ diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 47cae538e2e3f..1e374890e7d52 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -38,6 +38,7 @@ #include "arrow/compute/row/grouper.h" #include "arrow/record_batch.h" #include "arrow/stl_allocator.h" +#include "arrow/type_traits.h" #include "arrow/util/bit_run_reader.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_writer.h" @@ -441,9 +442,10 @@ struct GroupedCountImpl : public GroupedAggregator { // ---------------------------------------------------------------------- // Sum/Mean/Product implementation -template +template ::Type> struct GroupedReducingAggregator : public GroupedAggregator { - using AccType = typename FindAccumulatorType::Type; + using AccType = AccumulateType; using CType = typename TypeTraits::CType; using InputCType = typename TypeTraits::CType; @@ -483,7 +485,8 @@ struct GroupedReducingAggregator : public GroupedAggregator { Status Merge(GroupedAggregator&& raw_other, const ArrayData& group_id_mapping) override { - auto other = checked_cast*>(&raw_other); + auto other = + checked_cast*>(&raw_other); CType* reduced = reduced_.mutable_data(); int64_t* counts = counts_.mutable_data(); @@ -733,9 +736,18 @@ using GroupedProductFactory = // ---------------------------------------------------------------------- // Mean implementation +template +struct GroupedMeanAccType { + using Type = typename std::conditional::value, DoubleType, + typename FindAccumulatorType::Type>::type; +}; + template -struct GroupedMeanImpl : public GroupedReducingAggregator> { - using Base = GroupedReducingAggregator>; +struct GroupedMeanImpl + : public GroupedReducingAggregator, + typename GroupedMeanAccType::Type> { + using Base = GroupedReducingAggregator, + typename GroupedMeanAccType::Type>; using CType = typename Base::CType; using InputCType = typename Base::InputCType; using MeanType = @@ -746,7 +758,7 @@ struct GroupedMeanImpl : public GroupedReducingAggregator static enable_if_number Reduce(const DataType&, const CType u, const InputCType v) { - return static_cast(to_unsigned(u) + to_unsigned(static_cast(v))); + return static_cast(u) + static_cast(v); } static CType Reduce(const DataType&, const CType u, const CType v) {