Skip to content

Commit

Permalink
avoid hash mean overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
js8544 committed Dec 19, 2023
1 parent 75c6b64 commit 7b2f703
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
36 changes: 36 additions & 0 deletions cpp/src/arrow/acero/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,42 @@ TEST_P(GroupBy, SumMeanProductScalar) {
}
}

// Checks hash_mean on int64 inputs that sit just below the int64 maximum.
// Every group receives several such values, so adding them up in an integer
// accumulator would overflow; the expected per-group mean is the value itself.
TEST_P(GroupBy, MeanOverflow) {
  // Expected output: one struct row per key holding the float64 mean.
  const auto expected_type = struct_({
      field("key", int64()),
      field("hash_mean", float64()),
  });
  const char* expected_json = R"([
[1, 9223372036854775805],
[2, 9223372036854775805],
[3, 9223372036854775805]
])";

  BatchesWithSchema input;
  input.schema = schema({field("argument", int64()), field("key", int64())});
  // would overflow if intermediate sum is integer
  input.batches = {
      // Scalar-shaped argument column paired with an array of keys.
      ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
                        "[[9223372036854775805, 1], [9223372036854775805, 1], "
                        "[9223372036854775805, 2], [9223372036854775805, 3]]"),
      // Same shapes, but every argument value is null.
      ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
                        "[[null, 1], [null, 1], [null, 2], [null, 3]]"),
      // No explicit shapes given for this batch.
      ExecBatchFromJSON({int64(), int64()},
                        "[[9223372036854775805, 1], [9223372036854775805, 2], "
                        "[9223372036854775805, 3]]"),
  };

  // Run once with threads enabled and once without.
  for (const bool threaded : {true, false}) {
    SCOPED_TRACE(threaded ? "parallel/merged" : "serial");
    ASSERT_OK_AND_ASSIGN(
        Datum actual,
        RunGroupBy(input, {"key"},
                   {
                       {"hash_mean", nullptr, "argument", "hash_mean"},
                   },
                   threaded));
    Datum expected = ArrayFromJSON(expected_type, expected_json);
    AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
  }
}

TEST_P(GroupBy, VarianceAndStddev) {
auto batch = RecordBatchFromJSON(
schema({field("argument", int32()), field("key", int64())}), R"([
Expand Down
24 changes: 18 additions & 6 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "arrow/compute/row/grouper.h"
#include "arrow/record_batch.h"
#include "arrow/stl_allocator.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit_run_reader.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_writer.h"
Expand Down Expand Up @@ -441,9 +442,10 @@ struct GroupedCountImpl : public GroupedAggregator {
// ----------------------------------------------------------------------
// Sum/Mean/Product implementation

template <typename Type, typename Impl>
template <typename Type, typename Impl,
typename AccumulateType = typename FindAccumulatorType<Type>::Type>
struct GroupedReducingAggregator : public GroupedAggregator {
using AccType = typename FindAccumulatorType<Type>::Type;
using AccType = AccumulateType;
using CType = typename TypeTraits<AccType>::CType;
using InputCType = typename TypeTraits<Type>::CType;

Expand Down Expand Up @@ -483,7 +485,8 @@ struct GroupedReducingAggregator : public GroupedAggregator {

Status Merge(GroupedAggregator&& raw_other,
const ArrayData& group_id_mapping) override {
auto other = checked_cast<GroupedReducingAggregator<Type, Impl>*>(&raw_other);
auto other =
checked_cast<GroupedReducingAggregator<Type, Impl, AccType>*>(&raw_other);

CType* reduced = reduced_.mutable_data();
int64_t* counts = counts_.mutable_data();
Expand Down Expand Up @@ -733,9 +736,18 @@ using GroupedProductFactory =
// ----------------------------------------------------------------------
// Mean implementation

// Chooses the accumulator type used when computing a grouped mean over T:
// DoubleType whenever is_number_type<T>::value holds, otherwise whatever
// FindAccumulatorType<T> yields for T.
template <typename T>
struct GroupedMeanAccType {
  using Type = std::conditional_t<is_number_type<T>::value, DoubleType,
                                  typename FindAccumulatorType<T>::Type>;
};

template <typename Type>
struct GroupedMeanImpl : public GroupedReducingAggregator<Type, GroupedMeanImpl<Type>> {
using Base = GroupedReducingAggregator<Type, GroupedMeanImpl<Type>>;
struct GroupedMeanImpl
: public GroupedReducingAggregator<Type, GroupedMeanImpl<Type>,
typename GroupedMeanAccType<Type>::Type> {
using Base = GroupedReducingAggregator<Type, GroupedMeanImpl<Type>,
typename GroupedMeanAccType<Type>::Type>;
using CType = typename Base::CType;
using InputCType = typename Base::InputCType;
using MeanType =
Expand All @@ -746,7 +758,7 @@ struct GroupedMeanImpl : public GroupedReducingAggregator<Type, GroupedMeanImpl<
template <typename T = Type>
static enable_if_number<T, CType> Reduce(const DataType&, const CType u,
const InputCType v) {
return static_cast<CType>(to_unsigned(u) + to_unsigned(static_cast<CType>(v)));
return static_cast<CType>(u) + static_cast<CType>(v);
}

static CType Reduce(const DataType&, const CType u, const CType v) {
Expand Down

0 comments on commit 7b2f703

Please sign in to comment.