From 4256ea3879f4d20cb907ed7a65456845d64796d4 Mon Sep 17 00:00:00 2001 From: Wei He Date: Thu, 8 Feb 2024 15:26:46 -0800 Subject: [PATCH] SimpleAggregateAdapter with FunctionState prototype Differential Revision: D53556520 --- velox/exec/Aggregate.h | 4 + velox/exec/AggregateInfo.cpp | 2 + velox/exec/SimpleAggregateAdapterExperiment.h | 568 ++++++++++++++++++ velox/exec/tests/CMakeLists.txt | 10 +- .../SimpleAggregateAdapterExperimentTest.cpp | 49 ++ .../SimpleAggregateFunctionsRegistration.h | 3 + velox/exec/tests/SimpleSumAggregate.cpp | 115 ++++ 7 files changed, 748 insertions(+), 3 deletions(-) create mode 100644 velox/exec/SimpleAggregateAdapterExperiment.h create mode 100644 velox/exec/tests/SimpleAggregateAdapterExperimentTest.cpp create mode 100644 velox/exec/tests/SimpleSumAggregate.cpp diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index d4f71c8ee4c2..b532a3a36b55 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -51,6 +51,10 @@ class Aggregate { return resultType_; } + virtual void initialize( + const TypePtr& resultType, + const std::vector& args) {} + // Returns the fixed number of bytes the accumulator takes on a group // row. Variable width accumulators will reference the variable // width part of the state from the fixed part. diff --git a/velox/exec/AggregateInfo.cpp b/velox/exec/AggregateInfo.cpp index 47f415b98e59..42d0bbdacabc 100644 --- a/velox/exec/AggregateInfo.cpp +++ b/velox/exec/AggregateInfo.cpp @@ -107,6 +107,8 @@ std::vector toAggregateInfo( aggResultType, operatorCtx.driverCtx()->queryConfig()); + info.function->initialize(aggResultType, info.constantInputs); + if (!isStreaming) { auto lambdas = extractLambdaInputs(aggregate); if (!lambdas.empty()) { diff --git a/velox/exec/SimpleAggregateAdapterExperiment.h b/velox/exec/SimpleAggregateAdapterExperiment.h new file mode 100644 index 000000000000..cbfc4e53b2a5 --- /dev/null +++ b/velox/exec/SimpleAggregateAdapterExperiment.h @@ -0,0 +1,568 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/exec/Aggregate.h" +#include "velox/expression/VectorReaders.h" +#include "velox/expression/VectorWriters.h" + +namespace facebook::velox::exec { + +// The writer type of T used in simple UDAF interface. An instance of +// out_type allows writing one row into the output vector. +template +using out_type = typename VectorExec::template resolver::out_type; + +// The reader type of T used in simple UDAF interface. An instance of +// arg_type allows reading one row from the input vector. This is used for UDAFs +// that have the default null behavior. +template +using arg_type = typename VectorExec::template resolver::in_type; + +// The reader type of T used in simple UDAF interface. An instance of +// arg_type allows reading one row from the input vector. This is used for UDAFs +// that have non-default null behavior. +template +using optional_arg_type = OptionalAccessor; + +template +class SimpleAggregateAdapterExperiment : public Aggregate { + public: + explicit SimpleAggregateAdapterExperiment(TypePtr resultType) + : Aggregate(std::move(resultType)) {} + + typename FUNC::FunctionState state_; + + void initialize( + const TypePtr& resultType, + const std::vector& constantInputs) override { + FUNC::initialize(state_, resultType, constantInputs); + } + + // Assume most aggregate functions have fixed-size accumulators. Functions + // that + // have non-fixed-size accumulators should overwrite `is_fixed_size_` in their + // accumulator structs. + template + struct accumulator_is_fixed_size : std::true_type {}; + + template + struct accumulator_is_fixed_size> + : std::integral_constant {}; + + // Assume most aggregate functions have default null behavior, i.e., ignoring + // rows that have null values in raw input and intermediate results, and + // returning null for groups of no input rows or only null rows. + // For example: select sum(col0) + // from (values (1, 10), (2, 20), (3, 30)) as t(col0, col1) + // where col0 > 10; -- NULL + // Functions that have non-default null behavior should overwrite + // `default_null_behavior_`. + // All accumulators are initialized to NULL before the aggregation starts. + // However, for functions that have default and non-default null behaviors, + // there are a couple of differences in their implementations. + // 1. When default_null_behavior_ is true, authors define + // void AccumulatorType::addInput(HashStringAllocator* allocator, + // exec::arg_type arg1, ...) + // void AccumulatorType::combine(HashStringAllocator* allocator, + // exec::arg_type arg) + // These functions only receive non-null input values. Input rows that contain + // at least one NULL argument are ignored. The accumulator of a group is set + // to non-null if at least one input is added to this group through addInput() + // or combine(). Similarly, authors define + // bool AccumulatorType::writeIntermediateResult( + // exec::out_type&out) + // bool AccumulatorType::writeFinalResult(exec::out_type&out) + // These functions are only called on groups of non-null accumulators. Groups + // that have NULL accumulators automatically become NULL in the result vector. + // These functions also return a bool indicating whether the current group + // should be a NULL in the result vector. + // + // 2. When default_null_behavior_ is false, authors define + // bool AccumulatorType::addInput(HashStringAllocator* allocator, + // exec::optional_arg_type arg1, ...) + // bool AccumulatorType::combine( + // HashStringAllocator* allocator, + // exec::optional_arg_type arg) + // These functions receive both non-null and null inputs. They return a bool + // indicating whether to set the current group's accumulator to non-null. If + // the accumulator of a group is already non-NULL, returning false from + // addInput() or combine() doesn't change this group's nullness. Authors also + // define + // bool AccumulatorType::writeIntermediateResult( + // bool nonNullGroup, + // exec::out_type& out) + // bool AccumulatorType::writeFinalResult( + // bool nonNullGroup, + // exec::out_type& out) + // These functions are called on groups of both non-null and null + // accumulators. These functions also return a bool indicating whether the + // current group should be a NULL in the result vector. + template + struct aggregate_default_null_behavior : std::true_type {}; + + template + struct aggregate_default_null_behavior< + T, + std::void_t> + : std::integral_constant {}; + + // Assume most aggregate functions do not use external memory. Functions that + // use external memory should overwrite `use_external_memory_` in their + // accumulator structs. + template + struct accumulator_use_external_memory : std::false_type {}; + + template + struct accumulator_use_external_memory< + T, + std::void_t> + : std::integral_constant {}; + + // Whether the accumulator type defines its destroy() method or not. If it is + // defined, we call the accumulator's destroy() in + // SimpleAggregateAdapter::destroy(). + template + struct accumulator_custom_destroy : std::false_type {}; + + template + struct accumulator_custom_destroy> + : std::true_type {}; + + // Whether the function defines its toIntermediate() method or not. If it is + // defined, SimpleAggregateAdapter::supportToIntermediate() returns true. + // Otherwise, SimpleAggregateAdapter::supportToIntermediate() returns false + // and SimpleAggregateAdapter::toIntermediate() is empty. + template + struct support_to_intermediate : std::false_type {}; + + template + struct support_to_intermediate> + : std::true_type {}; + + static constexpr bool aggregate_default_null_behavior_ = + aggregate_default_null_behavior::value; + + static constexpr bool accumulator_is_fixed_size_ = + accumulator_is_fixed_size::value; + + static constexpr bool accumulator_use_external_memory_ = + accumulator_use_external_memory::value; + + static constexpr bool accumulator_custom_destroy_ = + accumulator_custom_destroy::value; + + static constexpr bool support_to_intermediate_ = + support_to_intermediate::value; + + bool isFixedSize() const override { + return accumulator_is_fixed_size_; + } + + bool accumulatorUsesExternalMemory() const override { + return accumulator_use_external_memory_; + } + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(typename FUNC::AccumulatorType); + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) + typename FUNC::AccumulatorType(allocator_, state_); + } + } + + // Add raw input to accumulators. If the simple aggregation function has + // default null behavior, input rows that has nulls are skipped. Otherwise, + // the accumulator type's addInput() method handles null inputs. + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + if (inputDecoded_.size() < args.size()) { + inputDecoded_.resize(args.size()); + } + + for (column_index_t i = 0; i < args.size(); ++i) { + inputDecoded_[i].decode(*args[i], rows); + } + + addRawInputImpl( + groups, rows, std::make_index_sequence{}); + } + + // Similar to addRawInput, but add inputs to one single accumulator. + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + if (inputDecoded_.size() < args.size()) { + inputDecoded_.resize(args.size()); + } + + for (column_index_t i = 0; i < args.size(); ++i) { + inputDecoded_[i].decode(*args[i], rows); + } + + addSingleGroupRawInputImpl( + group, rows, std::make_index_sequence{}); + } + + bool supportsToIntermediate() const override { + return support_to_intermediate_; + } + + void toIntermediate( + const SelectivityVector& rows, + std::vector& args, + VectorPtr& result) const override { + if constexpr (support_to_intermediate_) { + std::vector inputDecoded{args.size()}; + for (column_index_t i = 0; i < args.size(); ++i) { + inputDecoded[i].decode(*args[i], rows); + } + + toIntermediateImpl( + inputDecoded, + rows, + result, + std::make_index_sequence{}); + } else { + VELOX_UNREACHABLE( + "toIntermediate should only be called when support_to_intermediate_ is true."); + } + } + + // Add intermediate results to accumulators. If the simple aggregation + // function has default null behavior, intermediate result rows that has nulls + // are skipped. Otherwise, the accumulator type's combine() method handles + // nulls. + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + intermediateDecoded_.decode(*args[0], rows); + + addIntermediateResultsImpl(groups, rows); + } + + // Similar to addIntermediateResults, but add intermediate results to one + // single accumulator. + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + intermediateDecoded_.decode(*args[0], rows); + + addSingleGroupIntermediateResultsImpl(group, rows); + } + + // Extract intermediate results to a vector. + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + VectorWriter writer; + auto vector = (*result) + ->as::type>(); + vector->resize(numGroups); + writer.init(*vector); + + for (auto i = 0; i < numGroups; ++i) { + writer.setOffset(i); + auto group = value(groups[i]); + + if constexpr (aggregate_default_null_behavior_) { + if (isNull(groups[i])) { + writer.commitNull(); + } else { + bool nonNull = + group->writeIntermediateResult(writer.current(), state_); + writer.commit(nonNull); + } + } else { + bool nonNull = group->writeIntermediateResult( + !isNull(groups[i]), writer.current()); + writer.commit(nonNull); + } + } + writer.finish(); + } + + // Extract final results to a vector. + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto flatResult = + (*result) + ->as::type>(); + flatResult->resize(numGroups); + + VectorWriter writer; + writer.init(*flatResult); + + for (auto i = 0; i < numGroups; ++i) { + writer.setOffset(i); + auto group = value(groups[i]); + + if constexpr (aggregate_default_null_behavior_) { + if (isNull(groups[i])) { + writer.commitNull(); + } else { + bool nonNull = group->writeFinalResult(writer.current(), state_); + writer.commit(nonNull); + } + } else { + bool nonNull = + group->writeFinalResult(!isNull(groups[i]), writer.current()); + writer.commit(nonNull); + } + } + writer.finish(); + } + + void destroy(folly::Range groups) override { + if constexpr (accumulator_custom_destroy_) { + for (auto group : groups) { + auto accumulator = value(group); + if (!isNull(group)) { + accumulator->destroy(allocator_); + } + } + } + destroyAccumulators(groups); + } + + private: + template + void addRawInputImpl( + char** groups, + const SelectivityVector& rows, + std::index_sequence) { + std::tuple>...> + readers{&inputDecoded_[Is]...}; + + if constexpr (aggregate_default_null_behavior_) { + rows.applyToSelected([&](auto row) { + // If any input is null, we ignore the whole row. + if (!(std::get(readers).isSet(row) && ...)) { + return; + } + std::optional> tracker; + if constexpr (!accumulator_is_fixed_size_) { + tracker.emplace(groups[row][rowSizeOffset_], *allocator_); + } + auto group = value(groups[row]); + group->addInput(allocator_, std::get(readers)[row]..., state_); + clearNull(groups[row]); + }); + } else { + rows.applyToSelected([&](auto row) { + std::optional> tracker; + if constexpr (!accumulator_is_fixed_size_) { + tracker.emplace(groups[row][rowSizeOffset_], *allocator_); + } + auto group = value(groups[row]); + bool nonNull = group->addInput( + allocator_, + OptionalAccessor>{ + &std::get(readers), (int64_t)row}..., + state_); + if (nonNull) { + clearNull(groups[row]); + } + }); + } + } + + template + void addSingleGroupRawInputImpl( + char* group, + const SelectivityVector& rows, + std::index_sequence) { + std::tuple>...> + readers{&inputDecoded_[Is]...}; + auto accumulator = value(group); + + if constexpr (aggregate_default_null_behavior_) { + rows.applyToSelected([&](auto row) { + // If any input is null, we ignore the whole row. + if (!(std::get(readers).isSet(row) && ...)) { + return; + } + std::optional> tracker; + if constexpr (!accumulator_is_fixed_size_) { + tracker.emplace(group[rowSizeOffset_], *allocator_); + } + accumulator->addInput( + allocator_, std::get(readers)[row]..., state_); + clearNull(group); + }); + } else { + rows.applyToSelected([&](auto row) { + std::optional> tracker; + if constexpr (!accumulator_is_fixed_size_) { + tracker.emplace(group[rowSizeOffset_], *allocator_); + } + bool nonNull = accumulator->addInput( + allocator_, + OptionalAccessor>{ + &std::get(readers), (int64_t)row}..., + state_); + if (nonNull) { + clearNull(group); + } + }); + } + } + + template + void toIntermediateImpl( + const std::vector& inputDecoded, + const SelectivityVector& rows, + VectorPtr& result, + std::index_sequence) const { + std::tuple>...> + readers{&inputDecoded[Is]...}; + + VELOX_CHECK(result); + result->ensureWritable(rows); + auto* rawNulls = result->mutableRawNulls(); + bits::fillBits(rawNulls, 0, result->size(), bits::kNull); + + constexpr auto intermediateKind = + SimpleTypeTrait::typeKind; + auto* flatResult = + result->as::type>(); + exec::VectorWriter writer; + writer.init(*flatResult); + + if constexpr (aggregate_default_null_behavior_) { + rows.applyToSelected([&](auto row) { + writer.setOffset(row); + // If any input is null, we ignore the whole row. + if (!(std::get(readers).isSet(row) && ...)) { + writer.commitNull(); + return; + } + bool nonNull = FUNC::toIntermediate( + writer.current(), std::get(readers)[row]..., state_); + writer.commit(nonNull); + }); + writer.finish(); + } else { + rows.applyToSelected([&](auto row) { + writer.setOffset(row); + bool nonNull = FUNC::toIntermediate( + writer.current(), + OptionalAccessor>{ + &std::get(readers), (int64_t)row}..., + state_); + writer.commit(nonNull); + }); + writer.finish(); + } + } + + // Implementation of addIntermediateResults when the intermediate type is not + // a Row type. + void addIntermediateResultsImpl( + char** groups, + const SelectivityVector& rows) { + VectorReader reader(&intermediateDecoded_); + + if constexpr (aggregate_default_null_behavior_) { + rows.applyToSelected([&](auto row) { + if (!reader.isSet(row)) { + return; + } + std::optional> tracker; + if constexpr (!accumulator_is_fixed_size_) { + tracker.emplace(groups[row][rowSizeOffset_], *allocator_); + } + auto group = value(groups[row]); + group->combine(allocator_, reader[row], state_); + clearNull(groups[row]); + }); + } else { + rows.applyToSelected([&](auto row) { + std::optional> tracker; + if constexpr (!accumulator_is_fixed_size_) { + tracker.emplace(groups[row][rowSizeOffset_], *allocator_); + } + auto group = value(groups[row]); + bool nonNull = group->combine( + allocator_, + OptionalAccessor{ + &reader, (int64_t)row}, + state_); + if (nonNull) { + clearNull(groups[row]); + } + }); + } + } + + // Implementation of addSingleGroupIntermediateResults when the intermediate + // type is not a Row type. + void addSingleGroupIntermediateResultsImpl( + char* group, + const SelectivityVector& rows) { + VectorReader reader(&intermediateDecoded_); + auto accumulator = value(group); + + if constexpr (aggregate_default_null_behavior_) { + rows.applyToSelected([&](auto row) { + if (!reader.isSet(row)) { + return; + } + std::optional> tracker; + if constexpr (!accumulator_is_fixed_size_) { + tracker.emplace(group[rowSizeOffset_], *allocator_); + } + accumulator->combine(allocator_, reader[row], state_); + clearNull(group); + }); + } else { + rows.applyToSelected([&](auto row) { + std::optional> tracker; + if constexpr (!accumulator_is_fixed_size_) { + tracker.emplace(group[rowSizeOffset_], *allocator_); + } + bool nonNull = accumulator->combine( + allocator_, + OptionalAccessor{ + &reader, (int64_t)row}, + state_); + if (nonNull) { + clearNull(group); + } + }); + } + } + + std::vector inputDecoded_; + DecodedVector intermediateDecoded_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/CMakeLists.txt b/velox/exec/tests/CMakeLists.txt index baf67dd3ea1d..aebb8621edeb 100644 --- a/velox/exec/tests/CMakeLists.txt +++ b/velox/exec/tests/CMakeLists.txt @@ -233,13 +233,17 @@ target_link_libraries( gtest_main) add_library(velox_simple_aggregate SimpleAverageAggregate.cpp - SimpleArrayAggAggregate.cpp) + SimpleArrayAggAggregate.cpp + SimpleSumAggregate.cpp) target_link_libraries(velox_simple_aggregate velox_exec velox_expression velox_expression_functions velox_aggregates) -add_executable(velox_simple_aggregate_test SimpleAggregateAdapterTest.cpp - Main.cpp) +add_executable( + velox_simple_aggregate_test + SimpleAggregateAdapterTest.cpp + SimpleAggregateAdapterExperimentTest.cpp + Main.cpp) target_link_libraries( velox_simple_aggregate_test velox_simple_aggregate velox_exec diff --git a/velox/exec/tests/SimpleAggregateAdapterExperimentTest.cpp b/velox/exec/tests/SimpleAggregateAdapterExperimentTest.cpp new file mode 100644 index 000000000000..17cf56c77d14 --- /dev/null +++ b/velox/exec/tests/SimpleAggregateAdapterExperimentTest.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/Aggregate.h" +#include "velox/exec/tests/SimpleAggregateFunctionsRegistration.h" +#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" + +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using facebook::velox::functions::aggregate::test::AggregationTestBase; + +namespace facebook::velox::aggregate::test { +namespace { + +class SimpleSumAggregationTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + allowInputShuffle(); + + registerSimpleSumAggregate("simple_sum"); + } +}; + +TEST_F(SimpleSumAggregationTest, basic) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 3, 4}), + makeFlatVector({true, true, false, false}), + }); + auto expected = makeRowVector( + {makeFlatVector({false, true}), makeFlatVector({-7, -3})}); + testAggregations({data}, {"c1"}, {"simple_sum(c0)"}, {expected}); +} + +} // namespace +} // namespace facebook::velox::aggregate::test diff --git a/velox/exec/tests/SimpleAggregateFunctionsRegistration.h b/velox/exec/tests/SimpleAggregateFunctionsRegistration.h index 06095ab92e61..c389640a2009 100644 --- a/velox/exec/tests/SimpleAggregateFunctionsRegistration.h +++ b/velox/exec/tests/SimpleAggregateFunctionsRegistration.h @@ -28,4 +28,7 @@ exec::AggregateRegistrationResult registerSimpleAverageAggregate( exec::AggregateRegistrationResult registerSimpleArrayAggAggregate( const std::string& name); +exec::AggregateRegistrationResult registerSimpleSumAggregate( + const std::string& name); + } // namespace facebook::velox::aggregate diff --git a/velox/exec/tests/SimpleSumAggregate.cpp b/velox/exec/tests/SimpleSumAggregate.cpp new file mode 100644 index 000000000000..b634293f010f --- /dev/null +++ b/velox/exec/tests/SimpleSumAggregate.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/Aggregate.h" +#include "velox/exec/SimpleAggregateAdapterExperiment.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/expression/VectorWriters.h" + +using namespace facebook::velox::exec; + +namespace facebook::velox::aggregate { + +namespace { + +// Returns negative sum of input values. +class SumAggregate { + public: + // Type(s) of input vector(s) wrapped in Row. + using InputType = Row; + + // Type of intermediate result vector wrapped in Row. + using IntermediateType = int64_t; + + // Type of output vector. + using OutputType = int64_t; + + struct FunctionState { + bool flag_; + }; + + static void initialize( + FunctionState& state, + const TypePtr& type, + const std::vector& constantInputs) { + state.flag_ = false; + } + + struct AccumulatorType { + int64_t sum_; + + AccumulatorType() = delete; + + // Constructor used in initializeNewGroups(). + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) { + sum_ = 0; + } + + void addInput( + HashStringAllocator* /*allocator*/, + int64_t data, + const FunctionState& /*state*/) { + sum_ += data; + } + + void combine( + HashStringAllocator* /*allocator*/, + int64_t other, + const FunctionState& /*state*/) { + sum_ += other; + } + + bool writeFinalResult(int64_t& out, const FunctionState& state) { + out = state.flag_ ? sum_ : -sum_; + return true; + } + + bool writeIntermediateResult(int64_t& out, const FunctionState& /*state*/) { + out = sum_; + return true; + } + }; +}; + +} // namespace + +exec::AggregateRegistrationResult registerSimpleSumAggregate( + const std::string& name) { + std::vector> signatures; + signatures.push_back(exec::AggregateFunctionSignatureBuilder() + .returnType("bigint") + .intermediateType("bigint") + .argumentType("bigint") + .build()); + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + return std::make_unique>( + resultType); + }, + true); +} + +} // namespace facebook::velox::aggregate