From 9cafbb26681a1488c16edaec231ba55c21543e3a Mon Sep 17 00:00:00 2001 From: mwish Date: Mon, 2 Sep 2024 23:38:24 +0800 Subject: [PATCH] GH-43768: [C++] Fix the case when boolean_{any|all} meets constant input with length in Acero (#43799) ### Rationale for this change See https://github.com/apache/arrow/issues/43768 ### What changes are included in this PR? Fix the case when boolean_{any|all} meets constant input with length in Acero ### Are these changes tested? Yes ### Are there any user-facing changes? no * GitHub Issue: #43768 Lead-authored-by: mwish Co-authored-by: mwish <1506118561@qq.com> Co-authored-by: Rossi Sun Signed-off-by: mwish --- cpp/src/arrow/acero/aggregate_node_test.cc | 52 +++++++++++++++++++ .../arrow/compute/kernels/aggregate_basic.cc | 16 +++--- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/acero/aggregate_node_test.cc b/cpp/src/arrow/acero/aggregate_node_test.cc index d398fb24b73d5..c623271db9fb4 100644 --- a/cpp/src/arrow/acero/aggregate_node_test.cc +++ b/cpp/src/arrow/acero/aggregate_node_test.cc @@ -210,5 +210,57 @@ TEST(GroupByNode, NoSkipNulls) { AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch}, out_batches.batches); } +TEST(ScalarAggregateNode, AnyAll) { + // GH-43768: boolean_any and boolean_all with constant input should work well + // when min_count != 0. + std::shared_ptr in_schema = schema({field("not_used", int32())}); + std::shared_ptr out_schema = schema({field("agg_out", boolean())}); + struct AnyAllCase { + std::string batches_json; + Expression literal; + std::string expected_json; + bool skip_nulls = false; + uint32_t min_count = 2; + }; + std::vector cases{ + {"[[42], [42], [42], [42]]", literal(true), "[[true]]"}, + {"[[42], [42], [42], [42]]", literal(false), "[[false]]"}, + {"[[42], [42], [42], [42]]", literal(BooleanScalar{}), "[[null]]"}, + {"[[42]]", literal(true), "[[null]]"}, + {"[[42], [42], [42]]", literal(true), "[[true]]"}, + {"[[42], [42], [42]]", literal(true), "[[null]]", /*skip_nulls=*/false, + /*min_count=*/4}, + {"[[42], [42], [42], [42]]", literal(BooleanScalar{}), "[[null]]", + /*skip_nulls=*/true}, + }; + for (const AnyAllCase& any_all_case : cases) { + for (auto func_name : {"any", "all"}) { + std::vector batches{ + ExecBatchFromJSON({int32()}, any_all_case.batches_json)}; + std::vector aggregates = { + Aggregate(func_name, + std::make_shared( + /*skip_nulls=*/any_all_case.skip_nulls, + /*min_count=*/any_all_case.min_count), + FieldRef("literal"))}; + + // And a projection to make the input including a Scalar Boolean + Declaration plan = Declaration::Sequence( + {{"exec_batch_source", ExecBatchSourceNodeOptions(in_schema, batches)}, + {"project", ProjectNodeOptions({any_all_case.literal}, {"literal"})}, + {"aggregate", AggregateNodeOptions(aggregates)}}); + + ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches, + DeclarationToExecBatches(plan)); + + ExecBatch expected_batch = + ExecBatchFromJSON({boolean()}, any_all_case.expected_json); + + AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch}, + out_batches.batches); + } + } +} + } // namespace acero } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 1fbcd6a249093..c5e0e6fd6e977 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -532,13 +532,13 @@ struct BooleanAnyImpl : public ScalarAggregator { } if (batch[0].is_scalar()) { const Scalar& scalar = *batch[0].scalar; - this->has_nulls = !scalar.is_valid; - this->any = scalar.is_valid && checked_cast(scalar).value; - this->count += scalar.is_valid; + this->has_nulls |= !scalar.is_valid; + this->any |= scalar.is_valid && checked_cast(scalar).value; + this->count += scalar.is_valid * batch.length; return Status::OK(); } const ArraySpan& data = batch[0].array; - this->has_nulls = data.GetNullCount() > 0; + this->has_nulls |= data.GetNullCount() > 0; this->count += data.length - data.GetNullCount(); arrow::internal::OptionalBinaryBitBlockCounter counter( data.buffers[0].data, data.offset, data.buffers[1].data, data.offset, @@ -603,13 +603,13 @@ struct BooleanAllImpl : public ScalarAggregator { } if (batch[0].is_scalar()) { const Scalar& scalar = *batch[0].scalar; - this->has_nulls = !scalar.is_valid; - this->count += scalar.is_valid; - this->all = !scalar.is_valid || checked_cast(scalar).value; + this->has_nulls |= !scalar.is_valid; + this->count += scalar.is_valid * batch.length; + this->all &= !scalar.is_valid || checked_cast(scalar).value; return Status::OK(); } const ArraySpan& data = batch[0].array; - this->has_nulls = data.GetNullCount() > 0; + this->has_nulls |= data.GetNullCount() > 0; this->count += data.length - data.GetNullCount(); arrow::internal::OptionalBinaryBitBlockCounter counter( data.buffers[1].data, data.offset, data.buffers[0].data, data.offset,