Skip to content

Commit

Permalink
[fix](function)adjust aggregate function's nullable property (#37330)
Browse files Browse the repository at this point in the history
## Proposed changes

in this PR have change some age function nullable property.
aggregate functions can be divided into 3 types by their nullable
properties:
- AlwaysNullable
- AlwaysNotNullable
- NullableAggregateFunction

this PR code of FE part from #37215,
the author is @starocean999
---------
Co-authored-by: starocean999 <12095047@qq.com>
  • Loading branch information
zhangstar333 authored Jul 15, 2024
1 parent 6492ee6 commit 0816969
Show file tree
Hide file tree
Showing 47 changed files with 4,179 additions and 807 deletions.
6 changes: 4 additions & 2 deletions be/src/agent/be_exec_version_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ class BeExecVersionManager {
* a. change the impl of percentile (need fix)
* b. clear old version of version 3->4
* c. change FunctionIsIPAddressInRange from AlwaysNotNullable to DependOnArguments
* d. change some agg function nullable property: PR #37215
*/
constexpr inline int BeExecVersionManager::max_be_exec_version = 5;
constexpr inline int BeExecVersionManager::min_be_exec_version = 0;

/// functional
constexpr inline int BITMAP_SERDE = 3;
constexpr inline int USE_NEW_SERDE = 4; // release on DORIS version 2.1
constexpr inline int OLD_WAL_SERDE = 3; // use to solve compatibility issues, see pr #32299
constexpr inline int USE_NEW_SERDE = 4; // release on DORIS version 2.1
constexpr inline int OLD_WAL_SERDE = 3; // use to solve compatibility issues, see pr #32299
constexpr inline int AGG_FUNCTION_NULLABLE = 5; // change some agg nullable property: PR #37215

} // namespace doris
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/aggregation_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ Status AggSinkOperatorX::prepare(RuntimeState* state) {
RETURN_IF_ERROR(_aggregate_evaluators[i]->prepare(
state, DataSinkOperatorX<AggSinkLocalState>::_child_x->row_desc(),
intermediate_slot_desc, output_slot_desc));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

_offsets_of_aggregate_states.resize(_aggregate_evaluators.size());
Expand Down Expand Up @@ -832,7 +833,6 @@ Status AggSinkOperatorX::open(RuntimeState* state) {

for (auto& _aggregate_evaluator : _aggregate_evaluators) {
RETURN_IF_ERROR(_aggregate_evaluator->open(state));
_aggregate_evaluator->set_version(state->be_exec_version());
}

return Status::OK();
Expand Down
1 change: 1 addition & 0 deletions be/src/pipeline/exec/analytic_source_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ Status AnalyticSourceOperatorX::prepare(RuntimeState* state) {
SlotDescriptor* output_slot_desc = _output_tuple_desc->slots()[i];
RETURN_IF_ERROR(_agg_functions[i]->prepare(state, _child_x->row_desc(),
intermediate_slot_desc, output_slot_desc));
_agg_functions[i]->set_version(state->be_exec_version());
_change_to_nullable_flags.push_back(output_slot_desc->is_nullable() &&
!_agg_functions[i]->data_type()->is_nullable());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ Status DistinctStreamingAggOperatorX::prepare(RuntimeState* state) {
SlotDescriptor* output_slot_desc = _output_tuple_desc->slots()[j];
RETURN_IF_ERROR(_aggregate_evaluators[i]->prepare(
state, _child_x->row_desc(), intermediate_slot_desc, output_slot_desc));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

for (size_t i = 0; i < _aggregate_evaluators.size(); ++i) {
Expand Down Expand Up @@ -421,7 +422,6 @@ Status DistinctStreamingAggOperatorX::open(RuntimeState* state) {

for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
RETURN_IF_ERROR(_aggregate_evaluators[i]->open(state));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

return Status::OK();
Expand Down
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/streaming_aggregation_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,7 @@ Status StreamingAggOperatorX::prepare(RuntimeState* state) {
SlotDescriptor* output_slot_desc = _output_tuple_desc->slots()[j];
RETURN_IF_ERROR(_aggregate_evaluators[i]->prepare(
state, _child_x->row_desc(), intermediate_slot_desc, output_slot_desc));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

_offsets_of_aggregate_states.resize(_aggregate_evaluators.size());
Expand Down Expand Up @@ -1239,7 +1240,6 @@ Status StreamingAggOperatorX::open(RuntimeState* state) {

for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
RETURN_IF_ERROR(_aggregate_evaluators[i]->open(state));
_aggregate_evaluators[i]->set_version(state->be_exec_version());
}

return Status::OK();
Expand Down
25 changes: 19 additions & 6 deletions be/src/vec/aggregate_functions/aggregate_function_covar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,19 @@ AggregateFunctionPtr create_function_single_value(const String& name,
}

template <bool is_nullable>
AggregateFunctionPtr create_aggregate_function_covariance_samp_old(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return create_function_single_value<AggregateFunctionSamp_OLDER, CovarSampName, SampData_OLDER,
is_nullable>(name, argument_types, result_is_nullable,
NULLABLE);
}

AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData,
is_nullable>(name, argument_types, result_is_nullable,
NULLABLE);
return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
}

AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name,
Expand All @@ -81,9 +88,15 @@ void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& facto
factory.register_alias("covar", "covar_pop");
}

void register_aggregate_function_covar_samp_old(AggregateFunctionSimpleFactory& factory) {
factory.register_alternative_function(
"covar_samp", create_aggregate_function_covariance_samp_old<NOTNULLABLE>);
factory.register_alternative_function(
"covar_samp", create_aggregate_function_covariance_samp_old<NULLABLE>, NULLABLE);
}

void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& factory) {
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NOTNULLABLE>);
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NULLABLE>,
NULLABLE);
factory.register_function_both("covar_samp", create_aggregate_function_covariance_samp);
register_aggregate_function_covar_samp_old(factory);
}
} // namespace doris::vectorized
37 changes: 34 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_covar.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include "agent/be_exec_version_manager.h"
#define POP true
#define NOTPOP false
#define NULLABLE true
Expand Down Expand Up @@ -224,7 +225,7 @@ struct PopData : Data {
};

template <typename T, typename Data>
struct SampData : Data {
struct SampData_OLDER : Data {
using ColVecResult =
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128V2>, ColumnFloat64>;
void insert_result_into(IColumn& to) const {
Expand All @@ -243,6 +244,24 @@ struct SampData : Data {
}
};

template <typename T, typename Data>
struct SampData : Data {
using ColVecResult =
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128V2>, ColumnFloat64>;
void insert_result_into(IColumn& to) const {
auto& col = assert_cast<ColVecResult&>(to);
if (this->count == 1 || this->count == 0) {
col.insert_default();
} else {
if constexpr (IsDecimalNumber<T>) {
col.get_data().push_back(this->get_samp_result().value());
} else {
col.get_data().push_back(this->get_samp_result());
}
}
}
};

template <typename Data>
struct CovarName : Data {
static const char* name() { return "covar"; }
Expand All @@ -269,7 +288,11 @@ class AggregateFunctionSampCovariance
if constexpr (is_pop) {
return Data::get_return_type();
} else {
return make_nullable(Data::get_return_type());
if (IAggregateFunction::version < AGG_FUNCTION_NULLABLE) {
return make_nullable(Data::get_return_type());
} else {
return Data::get_return_type();
}
}
}

Expand All @@ -278,7 +301,7 @@ class AggregateFunctionSampCovariance
if constexpr (is_pop) {
this->data(place).add(columns[0], columns[1], row_num);
} else {
if constexpr (is_nullable) {
if constexpr (is_nullable) { //this if check could remove with old function
const auto* nullable_column_x = check_and_get_column<ColumnNullable>(columns[0]);
const auto* nullable_column_y = check_and_get_column<ColumnNullable>(columns[1]);
if (!nullable_column_x->is_null_at(row_num) &&
Expand Down Expand Up @@ -313,6 +336,14 @@ class AggregateFunctionSampCovariance
}
};

template <typename Data, bool is_nullable>
class AggregateFunctionSamp_OLDER final
: public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
public:
AggregateFunctionSamp_OLDER(const DataTypes& argument_types_)
: AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable>(argument_types_) {}
};

template <typename Data, bool is_nullable>
class AggregateFunctionSamp final
: public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
Expand Down
82 changes: 65 additions & 17 deletions be/src/vec/aggregate_functions/aggregate_function_percentile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,47 @@
namespace doris::vectorized {

template <bool is_nullable>
AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_percentile_approx_older(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
return nullptr;
}
if (argument_types.size() == 2) {
return creator_without_type::create<
AggregateFunctionPercentileApproxTwoParams<is_nullable>>(
remove_nullable(argument_types), result_is_nullable);
AggregateFunctionPercentileApproxTwoParams_OLDER<is_nullable>>((argument_types),
result_is_nullable);
}
if (argument_types.size() == 3) {
return creator_without_type::create<
AggregateFunctionPercentileApproxThreeParams<is_nullable>>(
AggregateFunctionPercentileApproxThreeParams_OLDER<is_nullable>>(
remove_nullable(argument_types), result_is_nullable);
}
return nullptr;
}

AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
return nullptr;
}
if (argument_types.size() == 2) {
return creator_without_type::create<AggregateFunctionPercentileApproxTwoParams>(
argument_types, result_is_nullable);
}
if (argument_types.size() == 3) {
return creator_without_type::create<AggregateFunctionPercentileApproxThreeParams>(
argument_types, result_is_nullable);
}
return nullptr;
}

template <bool is_nullable>
AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted_older(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
Expand All @@ -55,17 +73,35 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
}
if (argument_types.size() == 3) {
return creator_without_type::create<
AggregateFunctionPercentileApproxWeightedThreeParams<is_nullable>>(
AggregateFunctionPercentileApproxWeightedThreeParams_OLDER<is_nullable>>(
remove_nullable(argument_types), result_is_nullable);
}
if (argument_types.size() == 4) {
return creator_without_type::create<
AggregateFunctionPercentileApproxWeightedFourParams<is_nullable>>(
AggregateFunctionPercentileApproxWeightedFourParams_OLDER<is_nullable>>(
remove_nullable(argument_types), result_is_nullable);
}
return nullptr;
}

AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
return nullptr;
}
if (argument_types.size() == 3) {
return creator_without_type::create<AggregateFunctionPercentileApproxWeightedThreeParams>(
argument_types, result_is_nullable);
}
if (argument_types.size() == 4) {
return creator_without_type::create<AggregateFunctionPercentileApproxWeightedFourParams>(
argument_types, result_is_nullable);
}
return nullptr;
}

void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("percentile",
creator_with_integer_type::creator<AggregateFunctionPercentile>);
Expand All @@ -74,14 +110,26 @@ void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& fact
creator_with_integer_type::creator<AggregateFunctionPercentileArray>);
}

void register_percentile_approx_old_function(AggregateFunctionSimpleFactory& factory) {
factory.register_alternative_function(
"percentile_approx", create_aggregate_function_percentile_approx_older<false>, false);
factory.register_alternative_function(
"percentile_approx", create_aggregate_function_percentile_approx_older<true>, true);
factory.register_alternative_function(
"percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted_older<false>, false);
factory.register_alternative_function(
"percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted_older<true>, true);
}

void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) {
factory.register_function("percentile_approx",
create_aggregate_function_percentile_approx<false>, false);
factory.register_function("percentile_approx",
create_aggregate_function_percentile_approx<true>, true);
factory.register_function("percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted<false>, false);
factory.register_function("percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted<true>, true);
factory.register_function_both("percentile_approx",
create_aggregate_function_percentile_approx);
factory.register_function_both("percentile_approx_weighted",
create_aggregate_function_percentile_approx_weighted);

register_percentile_approx_old_function(factory);
}

} // namespace doris::vectorized
Loading

0 comments on commit 0816969

Please sign in to comment.