Skip to content

Commit

Permalink
apacheGH-43768: [C++] Fix the case when boolean_{any|all} meets const…
Browse files Browse the repository at this point in the history
…ant input with length in Acero (apache#43799)

### Rationale for this change

See apache#43768

### What changes are included in this PR?

Fix the case when boolean_{any|all} meets constant input with length in Acero

### Are these changes tested?

Yes

### Are there any user-facing changes?

no

* GitHub Issue: apache#43768

Lead-authored-by: mwish <maplewish117@gmail.com>
Co-authored-by: mwish <1506118561@qq.com>
Co-authored-by: Rossi Sun <zanmato1984@gmail.com>
Signed-off-by: mwish <maplewish117@gmail.com>
  • Loading branch information
3 people authored Sep 2, 2024
1 parent 44d3f76 commit 9cafbb2
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
52 changes: 52 additions & 0 deletions cpp/src/arrow/acero/aggregate_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,57 @@ TEST(GroupByNode, NoSkipNulls) {
AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch}, out_batches.batches);
}

TEST(ScalarAggregateNode, AnyAll) {
// GH-43768: boolean_any and boolean_all with constant input should work well
// when min_count != 0.
std::shared_ptr<Schema> in_schema = schema({field("not_used", int32())});
std::shared_ptr<Schema> out_schema = schema({field("agg_out", boolean())});
struct AnyAllCase {
std::string batches_json;
Expression literal;
std::string expected_json;
bool skip_nulls = false;
uint32_t min_count = 2;
};
std::vector<AnyAllCase> cases{
{"[[42], [42], [42], [42]]", literal(true), "[[true]]"},
{"[[42], [42], [42], [42]]", literal(false), "[[false]]"},
{"[[42], [42], [42], [42]]", literal(BooleanScalar{}), "[[null]]"},
{"[[42]]", literal(true), "[[null]]"},
{"[[42], [42], [42]]", literal(true), "[[true]]"},
{"[[42], [42], [42]]", literal(true), "[[null]]", /*skip_nulls=*/false,
/*min_count=*/4},
{"[[42], [42], [42], [42]]", literal(BooleanScalar{}), "[[null]]",
/*skip_nulls=*/true},
};
for (const AnyAllCase& any_all_case : cases) {
for (auto func_name : {"any", "all"}) {
std::vector<ExecBatch> batches{
ExecBatchFromJSON({int32()}, any_all_case.batches_json)};
std::vector<Aggregate> aggregates = {
Aggregate(func_name,
std::make_shared<compute::ScalarAggregateOptions>(
/*skip_nulls=*/any_all_case.skip_nulls,
/*min_count=*/any_all_case.min_count),
FieldRef("literal"))};

// And a projection to make the input including a Scalar Boolean
Declaration plan = Declaration::Sequence(
{{"exec_batch_source", ExecBatchSourceNodeOptions(in_schema, batches)},
{"project", ProjectNodeOptions({any_all_case.literal}, {"literal"})},
{"aggregate", AggregateNodeOptions(aggregates)}});

ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches,
DeclarationToExecBatches(plan));

ExecBatch expected_batch =
ExecBatchFromJSON({boolean()}, any_all_case.expected_json);

AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch},
out_batches.batches);
}
}
}

} // namespace acero
} // namespace arrow
16 changes: 8 additions & 8 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,13 +532,13 @@ struct BooleanAnyImpl : public ScalarAggregator {
}
if (batch[0].is_scalar()) {
const Scalar& scalar = *batch[0].scalar;
this->has_nulls = !scalar.is_valid;
this->any = scalar.is_valid && checked_cast<const BooleanScalar&>(scalar).value;
this->count += scalar.is_valid;
this->has_nulls |= !scalar.is_valid;
this->any |= scalar.is_valid && checked_cast<const BooleanScalar&>(scalar).value;
this->count += scalar.is_valid * batch.length;
return Status::OK();
}
const ArraySpan& data = batch[0].array;
this->has_nulls = data.GetNullCount() > 0;
this->has_nulls |= data.GetNullCount() > 0;
this->count += data.length - data.GetNullCount();
arrow::internal::OptionalBinaryBitBlockCounter counter(
data.buffers[0].data, data.offset, data.buffers[1].data, data.offset,
Expand Down Expand Up @@ -603,13 +603,13 @@ struct BooleanAllImpl : public ScalarAggregator {
}
if (batch[0].is_scalar()) {
const Scalar& scalar = *batch[0].scalar;
this->has_nulls = !scalar.is_valid;
this->count += scalar.is_valid;
this->all = !scalar.is_valid || checked_cast<const BooleanScalar&>(scalar).value;
this->has_nulls |= !scalar.is_valid;
this->count += scalar.is_valid * batch.length;
this->all &= !scalar.is_valid || checked_cast<const BooleanScalar&>(scalar).value;
return Status::OK();
}
const ArraySpan& data = batch[0].array;
this->has_nulls = data.GetNullCount() > 0;
this->has_nulls |= data.GetNullCount() > 0;
this->count += data.length - data.GetNullCount();
arrow::internal::OptionalBinaryBitBlockCounter counter(
data.buffers[1].data, data.offset, data.buffers[0].data, data.offset,
Expand Down

0 comments on commit 9cafbb2

Please sign in to comment.