diff --git a/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala b/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala index 9ca884e23..3cbc513cf 100644 --- a/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala +++ b/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala @@ -28,6 +28,8 @@ case class ColumnarNumaBindingInfo( class ColumnarPluginConfig(conf: SQLConf) { val enableColumnarSort: Boolean = conf.getConfString("spark.sql.columnar.sort", "false").toBoolean + val enableColumnarCodegenSort: Boolean = + conf.getConfString("spark.sql.columnar.codegen.sort", "true").toBoolean val enableColumnarNaNCheck: Boolean = conf.getConfString("spark.sql.columnar.nanCheck", "false").toBoolean val enableColumnarBroadcastJoin: Boolean = diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarSorter.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarSorter.scala index 5afa6225d..54917ce01 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarSorter.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarSorter.scala @@ -233,6 +233,7 @@ object ColumnarSorter extends Logging { result_type: Int = 0): TreeNode = { logInfo(s"ColumnarSorter sortOrder is ${sortOrder}, outputAttributes is ${outputAttributes}") val NaNCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck + val codegen = ColumnarPluginConfig.getConf.enableColumnarCodegenSort /////////////// Prepare ColumnarSorter ////////////// val outputFieldList: List[Field] = outputAttributes.toList.map(expr => { val attr = ConverterUtils.getAttrFromExpr(expr) @@ -322,6 +323,11 @@ object ColumnarSorter extends Logging { TreeBuilder.makeLiteral(NaNCheck.asInstanceOf[java.lang.Boolean])), new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) + val codegen_node = TreeBuilder.makeFunction( + "codegen", + Lists.newArrayList(TreeBuilder.makeLiteral(codegen.asInstanceOf[java.lang.Boolean])), + new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) + val result_type_node = TreeBuilder.makeFunction( "result_type", Lists.newArrayList(TreeBuilder.makeLiteral(result_type.asInstanceOf[Integer])), @@ -337,6 +343,7 @@ object ColumnarSorter extends Logging { dir_node, nulls_order_node, NaN_check_node, + codegen_node, result_type_node), new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) diff --git a/cpp/src/codegen/arrow_compute/expr_visitor_impl.h b/cpp/src/codegen/arrow_compute/expr_visitor_impl.h index 73071858c..74f63acdb 100644 --- a/cpp/src/codegen/arrow_compute/expr_visitor_impl.h +++ b/cpp/src/codegen/arrow_compute/expr_visitor_impl.h @@ -630,14 +630,19 @@ class SortArraysToIndicesVisitorImpl : public ExprVisitorImpl { nulls_order_.push_back(order_val); } // fifth child specifies whether to check NaN when sorting - auto function_node = std::dynamic_pointer_cast(children[4]); - auto NaN_check_node = - std::dynamic_pointer_cast(function_node->children()[0]); - NaN_check_ = arrow::util::get(NaN_check_node->holder()); - - if (children.size() == 6) { + auto nan_func_node = std::dynamic_pointer_cast(children[4]); + auto NaN_lit_node = + std::dynamic_pointer_cast(nan_func_node->children()[0]); + NaN_check_ = arrow::util::get(NaN_lit_node->holder()); + // sixth child specifies whether to do codegen for mutiple-key sort + auto codegen_func_node = + std::dynamic_pointer_cast(children[5]); + auto codegen_lit_node = + std::dynamic_pointer_cast(codegen_func_node->children()[0]); + do_codegen_ = arrow::util::get(codegen_lit_node->holder()); + if (children.size() == 7) { auto type_node = std::dynamic_pointer_cast( - std::dynamic_pointer_cast(children[5])->children()[0]); + std::dynamic_pointer_cast(children[6])->children()[0]); result_type_ = arrow::util::get(type_node->holder()); } result_schema_ = arrow::schema(ret_fields); @@ -659,7 +664,7 @@ class SortArraysToIndicesVisitorImpl : public ExprVisitorImpl { } RETURN_NOT_OK(extra::SortArraysToIndicesKernel::Make( &p_->ctx_, result_schema_, sort_key_node_, key_field_list_, sort_directions_, - nulls_order_, NaN_check_, result_type_, &kernel_)); + nulls_order_, NaN_check_, do_codegen_, result_type_, &kernel_)); p_->signature_ = kernel_->GetSignature(); initialized_ = true; finish_return_type_ = ArrowComputeResultType::BatchIterator; @@ -711,6 +716,7 @@ class SortArraysToIndicesVisitorImpl : public ExprVisitorImpl { std::vector sort_directions_; std::vector nulls_order_; bool NaN_check_; + bool do_codegen_; int result_type_ = 0; std::shared_ptr result_schema_; }; diff --git a/cpp/src/codegen/arrow_compute/ext/cmp_function.h b/cpp/src/codegen/arrow_compute/ext/cmp_function.h new file mode 100644 index 000000000..a3e35102e --- /dev/null +++ b/cpp/src/codegen/arrow_compute/ext/cmp_function.h @@ -0,0 +1,668 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include "third_party/function.h" + +namespace sparkcolumnarplugin { +namespace codegen { +namespace arrowcompute { +namespace extra { + +template +class TypedComparator { + public: + TypedComparator() {} + + ~TypedComparator() {} + + func::function GetCompareFunc( + const arrow::ArrayVector& arrays, bool asc, bool nulls_first) { + uint64_t null_total = 0; + std::vector> typed_arrays; + for (int array_id = 0; array_id < arrays.size(); array_id++) { + null_total += arrays[array_id]->null_count(); + auto typed_array = std::dynamic_pointer_cast(arrays[array_id]); + typed_arrays.push_back(typed_array); + } + if (null_total == 0) { + if (asc) { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left < right; + } + }; + } else { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left > right; + } + }; + } + } else if (asc) { + if (nulls_first) { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 1; + } else if (is_right_null) { + cmp_res = 0; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left < right; + } + } + } + }; + } else { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 0; + } else if (is_right_null) { + cmp_res = 1; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left < right; + } + } + } + }; + } + } else if (nulls_first) { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 1; + } else if (is_right_null) { + cmp_res = 0; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left > right; + } + } + } + }; + } else { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 0; + } else if (is_right_null) { + cmp_res = 1; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left > right; + } + } + } + }; + } + } + + private: + using ArrayType = typename arrow::TypeTraits::ArrayType; +}; + +template +class FloatingComparator { + public: + FloatingComparator() {} + + ~FloatingComparator() {} + + func::function GetCompareFunc( + const arrow::ArrayVector& arrays, bool asc, bool nulls_first, bool nan_check) { + uint64_t null_total = 0; + std::vector> typed_arrays; + for (int array_id = 0; array_id < arrays.size(); array_id++) { + null_total += arrays[array_id]->null_count(); + auto typed_array = std::dynamic_pointer_cast(arrays[array_id]); + typed_arrays.push_back(typed_array); + } + if (null_total == 0) { + if (asc) { + if (nan_check) { + // null_total == 0, asc, nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + bool is_left_nan = std::isnan(left); + bool is_right_nan = std::isnan(right); + if (!is_left_nan || !is_right_nan) { + if (is_left_nan) { + cmp_res = 0; + } else if (is_right_nan) { + cmp_res = 1; + } else { + if (left != right) { + cmp_res = left < right; + } + } + } + }; + } else { + // null_total == 0, asc, !nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left < right; + } + }; + } + } else { + if (nan_check) { + // null_total == 0, desc, nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + bool is_left_nan = std::isnan(left); + bool is_right_nan = std::isnan(right); + if (!is_left_nan || !is_right_nan) { + if (is_left_nan) { + cmp_res = 1; + } else if (is_right_nan) { + cmp_res = 0; + } else { + if (left != right) { + cmp_res = left > right; + } + } + } + }; + } else { + // null_total == 0, desc, !nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left > right; + } + }; + } + } + } else if (asc) { + if (nulls_first) { + if (nan_check) { + // nulls_first, asc, nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 1; + } else if (is_right_null) { + cmp_res = 0; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + bool is_left_nan = std::isnan(left); + bool is_right_nan = std::isnan(right); + if (!is_left_nan || !is_right_nan) { + if (is_left_nan) { + cmp_res = 0; + } else if (is_right_nan) { + cmp_res = 1; + } else { + if (left != right) { + cmp_res = left < right; + } + } + } + } + } + }; + } else { + // nulls_first, asc, !nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 1; + } else if (is_right_null) { + cmp_res = 0; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left < right; + } + } + } + }; + } + } else { + if (nan_check) { + // nulls_last, asc, nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 0; + } else if (is_right_null) { + cmp_res = 1; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + bool is_left_nan = std::isnan(left); + bool is_right_nan = std::isnan(right); + if (!is_left_nan || !is_right_nan) { + if (is_left_nan) { + cmp_res = 0; + } else if (is_right_nan) { + cmp_res = 1; + } else { + if (left != right) { + cmp_res = left < right; + } + } + } + } + } + }; + } else { + // nulls_last, asc, !nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 0; + } else if (is_right_null) { + cmp_res = 1; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left < right; + } + } + } + }; + } + } + } else if (nulls_first) { + if (nan_check) { + // nulls_first, desc, nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 1; + } else if (is_right_null) { + cmp_res = 0; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + bool is_left_nan = std::isnan(left); + bool is_right_nan = std::isnan(right); + if (!is_left_nan || !is_right_nan) { + if (is_left_nan) { + cmp_res = 1; + } else if (is_right_nan) { + cmp_res = 0; + } else { + if (left != right) { + cmp_res = left > right; + } + } + } + } + } + }; + } else { + // nulls_first, desc, !nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 1; + } else if (is_right_null) { + cmp_res = 0; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left > right; + } + } + } + }; + } + } else { + if (nan_check) { + // nulls_last, desc, nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 0; + } else if (is_right_null) { + cmp_res = 1; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + bool is_left_nan = std::isnan(left); + bool is_right_nan = std::isnan(right); + if (!is_left_nan || !is_right_nan) { + if (is_left_nan) { + cmp_res = 1; + } else if (is_right_nan) { + cmp_res = 0; + } else { + if (left != right) { + cmp_res = left > right; + } + } + } + } + } + }; + } else { + // nulls_last, desc, !nan_check + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 0; + } else if (is_right_null) { + cmp_res = 1; + } else { + CType left = typed_arrays[left_array_id]->GetView(left_id); + CType right = typed_arrays[right_array_id]->GetView(right_id); + if (left != right) { + cmp_res = left > right; + } + } + } + }; + } + } + } + + private: + using ArrayType = typename arrow::TypeTraits::ArrayType; +}; + +template +class StringComparator { + public: + StringComparator() {} + + ~StringComparator() {} + + func::function GetCompareFunc( + const arrow::ArrayVector& arrays, bool asc, bool nulls_first) { + uint64_t null_total = 0; + std::vector> typed_arrays; + for (int array_id = 0; array_id < arrays.size(); array_id++) { + null_total += arrays[array_id]->null_count(); + auto typed_array = std::dynamic_pointer_cast(arrays[array_id]); + typed_arrays.push_back(typed_array); + } + if (null_total == 0) { + if (asc) { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + CType left = typed_arrays[left_array_id]->GetString(left_id); + CType right = typed_arrays[right_array_id]->GetString(right_id); + if (left != right) { + cmp_res = left < right; + } + }; + } else { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + CType left = typed_arrays[left_array_id]->GetString(left_id); + CType right = typed_arrays[right_array_id]->GetString(right_id); + if (left != right) { + cmp_res = left > right; + } + }; + } + } else if (asc) { + if (nulls_first) { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 1; + } else if (is_right_null) { + cmp_res = 0; + } else { + CType left = typed_arrays[left_array_id]->GetString(left_id); + CType right = typed_arrays[right_array_id]->GetString(right_id); + if (left != right) { + cmp_res = left < right; + } + } + } + }; + } else { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 0; + } else if (is_right_null) { + cmp_res = 1; + } else { + CType left = typed_arrays[left_array_id]->GetString(left_id); + CType right = typed_arrays[right_array_id]->GetString(right_id); + if (left != right) { + cmp_res = left < right; + } + } + } + }; + } + } else if (nulls_first) { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 1; + } else if (is_right_null) { + cmp_res = 0; + } else { + CType left = typed_arrays[left_array_id]->GetString(left_id); + CType right = typed_arrays[right_array_id]->GetString(right_id); + if (left != right) { + cmp_res = left > right; + } + } + } + }; + } else { + return [=](int left_array_id, int right_array_id, + int64_t left_id, int64_t right_id, int& cmp_res) { + bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 && + typed_arrays[left_array_id]->IsNull(left_id); + bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 && + typed_arrays[right_array_id]->IsNull(right_id); + if (!is_left_null || !is_right_null) { + if (is_left_null) { + cmp_res = 0; + } else if (is_right_null) { + cmp_res = 1; + } else { + CType left = typed_arrays[left_array_id]->GetString(left_id); + CType right = typed_arrays[right_array_id]->GetString(right_id); + if (left != right) { + cmp_res = left > right; + } + } + } + }; + } + } + + private: + using ArrayType = typename arrow::TypeTraits::ArrayType; +}; + +#define PROCESS_SUPPORTED_TYPES(PROCESS) \ + PROCESS(arrow::BooleanType) \ + PROCESS(arrow::UInt8Type) \ + PROCESS(arrow::Int8Type) \ + PROCESS(arrow::UInt16Type) \ + PROCESS(arrow::Int16Type) \ + PROCESS(arrow::UInt32Type) \ + PROCESS(arrow::Int32Type) \ + PROCESS(arrow::UInt64Type) \ + PROCESS(arrow::Int64Type) \ + PROCESS(arrow::Date32Type) \ + PROCESS(arrow::Date64Type) +static arrow::Status MakeCmpFunction( + const std::vector& array_vectors, + const std::vector>& key_field_list, + const std::vector& key_index_list, + const std::vector& sort_directions, + const std::vector& nulls_order, + const bool& nan_check, + std::vector>& cmp_functions) { + for (int i = 0; i < key_field_list.size(); i++) { + auto type = key_field_list[i]->type(); + int key_col_id = key_index_list[i]; + arrow::ArrayVector col = array_vectors[key_col_id]; + bool asc = sort_directions[i]; + bool nulls_first = nulls_order[i]; + if (type->id() == arrow::Type::STRING) { + auto comparator_ptr = + std::make_shared>(); + cmp_functions.push_back( + comparator_ptr->GetCompareFunc(col, asc, nulls_first)); + } else if (type->id() == arrow::Type::DOUBLE) { + auto comparator_ptr = + std::make_shared>(); + cmp_functions.push_back( + comparator_ptr->GetCompareFunc(col, asc, nulls_first, nan_check)); + } else if (type->id() == arrow::Type::FLOAT) { + auto comparator_ptr = + std::make_shared>(); + cmp_functions.push_back( + comparator_ptr->GetCompareFunc(col, asc, nulls_first, nan_check)); + } else { + switch (type->id()) { + #define PROCESS(InType) \ + case InType::type_id: { \ + using CType = typename arrow::TypeTraits::CType; \ + auto comparator_ptr = std::make_shared>(); \ + cmp_functions.push_back(comparator_ptr->GetCompareFunc(col, asc, nulls_first));\ + } break; + PROCESS_SUPPORTED_TYPES(PROCESS) + #undef PROCESS + default: { + std::cout << "MakeCmpFunction type not supported, type is " + << type << std::endl; + } break; + } + } + } + return arrow::Status::OK(); +} +#undef PROCESS_SUPPORTED_TYPES + + +} // namespace extra +} // namespace arrowcompute +} // namespace codegen +} // namespace sparkcolumnarplugin diff --git a/cpp/src/codegen/arrow_compute/ext/kernels_ext.h b/cpp/src/codegen/arrow_compute/ext/kernels_ext.h index 1fcd55f91..203ca5e09 100644 --- a/cpp/src/codegen/arrow_compute/ext/kernels_ext.h +++ b/cpp/src/codegen/arrow_compute/ext/kernels_ext.h @@ -319,14 +319,19 @@ class SortArraysToIndicesKernel : public KernalBase { gandiva::NodeVector sort_key_node, std::vector> key_field_list, std::vector sort_directions, - std::vector nulls_order, bool NaN_check, - int result_type, std::shared_ptr* out); + std::vector nulls_order, + bool NaN_check, + bool do_codegen, + int result_type, + std::shared_ptr* out); SortArraysToIndicesKernel(arrow::compute::FunctionContext* ctx, std::shared_ptr result_schema, gandiva::NodeVector sort_key_node, std::vector> key_field_list, std::vector sort_directions, - std::vector nulls_order, bool NaN_check, + std::vector nulls_order, + bool NaN_check, + bool do_codegen, int result_type); arrow::Status Evaluate(const ArrayList& in) override; arrow::Status MakeResultIterator( diff --git a/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc b/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc index 9e48ef85d..b8297dc9a 100644 --- a/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc +++ b/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc @@ -32,6 +32,7 @@ #include #include "array_appender.h" +#include "cmp_function.h" #include "codegen/arrow_compute/ext/array_item_index.h" #include "codegen/arrow_compute/ext/code_generator_base.h" #include "codegen/arrow_compute/ext/codegen_common.h" @@ -40,24 +41,27 @@ #include "precompile/array.h" #include "precompile/type.h" #include "third_party/ska_sort.hpp" +#include "third_party/timsort.hpp" #include "utils/macros.h" /** The Overall Implementation of Sort Kernel - * In general, there are three kenels to use when sorting for different data. - They are SortInplaceKernel, SortOnekeyKernel and SortArraysToIndicesKernel. + * In general, there are four kenels to use when sorting for different data. + They are SortInplaceKernel, SortOnekeyKernel, SortArraysToIndicesKernel + and SortMultiplekeyKernel. * If sorting for one non-string and non-bool col without payload, SortInplaceKernel - is used. In this kernel, if sorted data has no null value, ska_sort is used for - asc direciton, and std sort is used for desc direciton. If sorted data has null - value, arrow sort is used. Data is partitioned to null, NaN (for double and - float only) and valid value before sort. - * If sorting for one col with payload, and one string or bool col without payload, + is used. In this kernel, ska_sort is used for asc direciton, and std sort is used + for desc direciton. Data is partitioned to null, NaN (for double and float only) + and valid value before sort. + * If sorting for single key with payload, and one string or bool col without payload, SortOnekeyKernel is used. In this kernel, ska_sort is used for asc direciton, and std sort is used for desc direciton. Data is partitioned to null, NaN (for double and float only) and valid value before sort. - * If sorting for multiple cols, SortArraysToIndicesKernel is used. This kernel - will do codegen, and std sort is used. - * Projection is supported in all the three kernels. If projection is required, + * If sorting for multiple keys, there are two kernels to use. When enabling codegen, + * SortArraysToIndicesKernel is used, which will do codegen. When disabling codegen, + * SortMultiplekeyKernel is used, which uses std::function to do comparison. In both + * kernels, timsort is used. + * Projection is supported in all the four kernels. If projection is required, projection is completed before sort, and the projected cols are used to do comparison. FIXME: 1. datatype change after projection is not supported in Inplace. @@ -78,7 +82,8 @@ class SortArraysToIndicesKernel::Impl { std::shared_ptr key_projector, std::vector> projected_types, std::vector> key_field_list, - std::vector sort_directions, std::vector nulls_order, bool NaN_check) + std::vector sort_directions, std::vector nulls_order, + bool NaN_check) : ctx_(ctx), result_schema_(result_schema), key_projector_(key_projector), @@ -816,6 +821,8 @@ class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { return arrow::Status::OK(); } + // This function is used for float/double data without null value. + // If NaN_check_ is true, we need to do partition for NaN before sort. template auto SortNoNull(TYPE* indices_begin, TYPE* indices_end) -> typename std::enable_if_t::value> { @@ -837,6 +844,7 @@ class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { } } + // This function is used for non-float and non-double data without null value. template auto SortNoNull(TYPE* indices_begin, TYPE* indices_end) -> typename std::enable_if_t::value> { @@ -848,6 +856,8 @@ class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { } } + // This function is used for float/double data with null value. + // We should do partition for null and NaN (if (NaN_check_ is true). template auto Sort(int64_t* indices_begin, int64_t* indices_end, const ArrayType& values) -> typename std::enable_if_t::value> { @@ -915,6 +925,8 @@ class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { } } + // This function is used for non-float and non-double data with null value. + // We should do partition for null. template auto Sort(int64_t* indices_begin, int64_t* indices_end, const ArrayType& values) -> typename std::enable_if_t::value> { @@ -957,6 +969,7 @@ class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { RETURN_NOT_OK( arrow::Concatenate(cached_0_, ctx_->memory_pool(), &concatenated_array_)); if (nulls_total_ > 0) { + // Function Sort is used. auto typed_array = std::dynamic_pointer_cast(concatenated_array_); std::shared_ptr indices_out; @@ -976,6 +989,7 @@ class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { *out = std::make_shared(ctx_, schema, sort_out, nulls_first_, asc_); } else { + // Function SortNoNull is used. CTYPE* indices_begin = concatenated_array_->data()->GetMutableValues(1); CTYPE* indices_end = indices_begin + concatenated_array_->length(); @@ -1025,6 +1039,8 @@ class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { return true; } + // This class is used to copy a piece of memory from the sorted ArrayData + // to a result array. template class SliceImpl { public: @@ -1151,8 +1167,7 @@ class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { arrow::ArrayData result_data = *result_arr_->data(); arrow::ArrayData out_data; SliceImpl(result_data, ctx_->memory_pool(), length, total_offset_, - nulls_total_, nulls_first_, total_length_) - .Slice(&out_data); + nulls_total_, nulls_first_, total_length_).Slice(&out_data); std::shared_ptr out_0 = MakeArray(std::make_shared(std::move(out_data))); total_offset_ += length; @@ -1196,7 +1211,7 @@ class SortOnekeyKernel : public SortArraysToIndicesKernel::Impl { key_projector_(key_projector), NaN_check_(NaN_check) { #ifdef DEBUG - std::cout << "UseSortOneKeyForArithmetic" << std::endl; + std::cout << "UseSortOnekeyKernel" << std::endl; #endif auto indices = result_schema->GetAllFieldIndices(key_field_list[0]->name()); if (indices.size() < 1) { @@ -1574,15 +1589,275 @@ class SortOnekeyKernel : public SortArraysToIndicesKernel::Impl { }; }; +/////////////// SortArraysMultipleKeys //////////////// +class SortMultiplekeyKernel : public SortArraysToIndicesKernel::Impl { + public: + SortMultiplekeyKernel(arrow::compute::FunctionContext* ctx, + std::shared_ptr result_schema, + std::shared_ptr key_projector, + std::vector> projected_types, + std::vector> key_field_list, + std::vector sort_directions, + std::vector nulls_order, + bool NaN_check) + : ctx_(ctx), + nulls_order_(nulls_order), + sort_directions_(sort_directions), + result_schema_(result_schema), + key_projector_(key_projector), + key_field_list_(key_field_list), + NaN_check_(NaN_check) { + #ifdef DEBUG + std::cout << "UseSortMultiplekeyKernel" << std::endl; + #endif + for (auto field : key_field_list) { + auto indices = result_schema->GetAllFieldIndices(field->name()); + if (indices.size() != 1) { + std::cout << "[ERROR] SortArraysToIndicesKernel::Impl can't find key " + << field->ToString() << " from " << result_schema->ToString() + << std::endl; + throw; + } + key_index_list_.push_back(indices[0]); + } + col_num_ = result_schema->num_fields(); + int i = 0; + for (auto type : projected_types) { + auto field = arrow::field(std::to_string(i), type); + projected_field_list_.push_back(field); + i++; + } + } + ~SortMultiplekeyKernel(){} + + arrow::Status Evaluate(const ArrayList& in) override { + num_batches_++; + if (cached_.size() <= col_num_) { + cached_.resize(col_num_ + 1); + } + for (int i = 0; i < col_num_; i++) { + cached_[i].push_back(in[i]); + } + if (key_projector_) { + int projected_col_num = projected_field_list_.size(); + if (projected_.size() <= projected_col_num) { + projected_.resize(projected_col_num + 1); + } + std::vector> projected_batch; + // do projection here, and the projected arrays are used for comparison + auto length = in.size() > 0 ? in[0]->length() : 0; + auto in_batch = arrow::RecordBatch::Make(result_schema_, length, in); + RETURN_NOT_OK( + key_projector_->Evaluate(*in_batch, ctx_->memory_pool(), &projected_batch)); + for (int i = 0; i < projected_col_num; i++) { + std::shared_ptr col = projected_batch[i]; + projected_[i].push_back(col); + } + } + items_total_ += in[0]->length(); + length_list_.push_back(in[0]->length()); + return arrow::Status::OK(); + } + + int compareInternal(int left_array_id, int64_t left_id, int right_array_id, + int64_t right_id, int keys_num) { + int key_idx = 0; + while (key_idx < keys_num) { + // In comparison, 1 represents for true, 0 for false, and 2 for equal. + int cmp_res = 2; + cmp_functions_[key_idx](left_array_id, right_array_id, + left_id, right_id, cmp_res); + if (cmp_res != 2) { + return cmp_res; + } + key_idx += 1; + } + return 2; + } + + bool compareRow(int left_array_id, int64_t left_id, int right_array_id, + int64_t right_id, int keys_num) { + if (compareInternal(left_array_id, left_id, right_array_id, + right_id, keys_num) == 1) { + return true; + } + return false; + } + + auto Sort(ArrayItemIndexS* indices_begin, ArrayItemIndexS* indices_end) { + int keys_num = sort_directions_.size(); + auto comp = [this, &keys_num](ArrayItemIndexS x, ArrayItemIndexS y) { + return compareRow(x.array_id, x.id, y.array_id, y.id, keys_num);}; + gfx::timsort(indices_begin, indices_begin + items_total_, comp); + } + + void Partition(ArrayItemIndexS* indices_begin, + ArrayItemIndexS* indices_end) { + int64_t indices_i = 0; + int64_t indices_null = 0; + for (int array_id = 0; array_id < num_batches_; array_id++) { + for (int64_t i = 0; i < length_list_[array_id]; i++) { + (indices_begin + indices_i)->array_id = array_id; + (indices_begin + indices_i)->id = i; + indices_i++; + } + } + } + + arrow::Status FinishInternal(std::shared_ptr* out) { + // initiate buffer for all arrays + std::shared_ptr indices_buf; + int64_t buf_size = items_total_ * sizeof(ArrayItemIndexS); + RETURN_NOT_OK(arrow::AllocateBuffer(ctx_->memory_pool(), buf_size, &indices_buf)); + ArrayItemIndexS* indices_begin = + reinterpret_cast(indices_buf->mutable_data()); + ArrayItemIndexS* indices_end = indices_begin + items_total_; + // do partition and sort here + Partition(indices_begin, indices_end); + if (key_projector_) { + std::vector projected_key_idx_list; + for (int i = 0; i < projected_field_list_.size(); i++) { + projected_key_idx_list.push_back(i); + } + MakeCmpFunction( + projected_, projected_field_list_, projected_key_idx_list, sort_directions_, + nulls_order_, NaN_check_, cmp_functions_); + } else { + MakeCmpFunction( + cached_, key_field_list_, key_index_list_, sort_directions_, + nulls_order_, NaN_check_, cmp_functions_); + } + Sort(indices_begin, indices_end); + std::shared_ptr out_type; + RETURN_NOT_OK( + MakeFixedSizeBinaryType(sizeof(ArrayItemIndexS) / sizeof(int32_t), &out_type)); + RETURN_NOT_OK(MakeFixedSizeBinaryArray(out_type, items_total_, indices_buf, out)); + return arrow::Status::OK(); + } + + arrow::Status MakeResultIterator( + std::shared_ptr schema, + std::shared_ptr>* out) override { + std::shared_ptr indices_out; + RETURN_NOT_OK(FinishInternal(&indices_out)); + *out = std::make_shared(ctx_, schema, indices_out, cached_); + return arrow::Status::OK(); + } + + private: + std::vector cached_; + std::vector projected_; + arrow::compute::FunctionContext* ctx_; + std::shared_ptr result_schema_; + std::shared_ptr key_projector_; + std::vector> key_field_list_; + std::vector> projected_field_list_; + std::vector nulls_order_; + std::vector sort_directions_; + std::vector key_index_list_; + bool NaN_check_; + std::vector length_list_; + uint64_t num_batches_ = 0; + uint64_t items_total_ = 0; + int col_num_; + std::vector> cmp_functions_; + + class SorterResultIterator : public ResultIterator { + public: + SorterResultIterator(arrow::compute::FunctionContext* ctx, + std::shared_ptr schema, + std::shared_ptr indices_in, + std::vector& cached) + : ctx_(ctx), + schema_(schema), + indices_in_cache_(indices_in), + total_length_(indices_in->length()), + cached_in_(cached) { + col_num_ = schema->num_fields(); + indices_begin_ = (ArrayItemIndexS*)indices_in->value_data(); + // appender_type won't be used + AppenderBase::AppenderType appender_type = AppenderBase::left; + for (int i = 0; i < col_num_; i++) { + auto field = schema->field(i); + std::shared_ptr appender; + MakeAppender(ctx_, field->type(), appender_type, &appender); + appender_list_.push_back(appender); + } + for (int i = 0; i < col_num_; i++) { + arrow::ArrayVector array_vector = cached_in_[i]; + int array_num = array_vector.size(); + for (int array_id = 0; array_id < array_num; array_id++) { + auto arr = array_vector[array_id]; + appender_list_[i]->AddArray(arr); + } + } + batch_size_ = GetBatchSize(); + } + ~SorterResultIterator(){} + + std::string ToString() override { return "SortArraysToIndicesResultIterator"; } + + bool HasNext() override { + if (offset_ >= total_length_) { + return false; + } + return true; + } + + arrow::Status Next(std::shared_ptr* out) { + auto length = (total_length_ - offset_) > batch_size_ ? batch_size_ + : (total_length_ - offset_); + uint64_t count = 0; + for (int i = 0; i < col_num_; i++) { + while (count < length) { + auto item = indices_begin_ + offset_ + count++; + RETURN_NOT_OK(appender_list_[i]->Append(item->array_id, item->id)); + } + count = 0; + } + offset_ += length; + ArrayList arrays; + for (int i = 0; i < col_num_; i++) { + std::shared_ptr out_array; + RETURN_NOT_OK(appender_list_[i]->Finish(&out_array)); + arrays.push_back(out_array); + appender_list_[i]->Reset(); + } + + *out = arrow::RecordBatch::Make(schema_, length, arrays); + return arrow::Status::OK(); + } + + private: + uint64_t offset_ = 0; + const uint64_t total_length_; + std::shared_ptr schema_; + arrow::compute::FunctionContext* ctx_; + uint64_t batch_size_; + int col_num_; + ArrayItemIndexS* indices_begin_; + std::vector cached_in_; + std::vector> type_list_; + std::vector> appender_list_; + std::vector> array_list_; + std::shared_ptr indices_in_cache_; + }; +}; + arrow::Status SortArraysToIndicesKernel::Make( - arrow::compute::FunctionContext* ctx, std::shared_ptr result_schema, + arrow::compute::FunctionContext* ctx, + std::shared_ptr result_schema, gandiva::NodeVector sort_key_node, std::vector> key_field_list, - std::vector sort_directions, std::vector nulls_order, bool NaN_check, - int result_type, std::shared_ptr* out) { - *out = std::make_shared(ctx, result_schema, sort_key_node, - key_field_list, sort_directions, - nulls_order, NaN_check, result_type); + std::vector sort_directions, + std::vector nulls_order, + bool NaN_check, + bool do_codegen, + int result_type, + std::shared_ptr* out) { + *out = std::make_shared( + ctx, result_schema, sort_key_node, key_field_list, sort_directions, nulls_order, + NaN_check, do_codegen, result_type); return arrow::Status::OK(); } #define PROCESS_SUPPORTED_TYPES(PROCESS) \ @@ -1600,10 +1875,14 @@ arrow::Status SortArraysToIndicesKernel::Make( PROCESS(arrow::Date32Type) \ PROCESS(arrow::Date64Type) SortArraysToIndicesKernel::SortArraysToIndicesKernel( - arrow::compute::FunctionContext* ctx, std::shared_ptr result_schema, + arrow::compute::FunctionContext* ctx, + std::shared_ptr result_schema, gandiva::NodeVector sort_key_node, std::vector> key_field_list, - std::vector sort_directions, std::vector nulls_order, bool NaN_check, + std::vector sort_directions, + std::vector nulls_order, + bool NaN_check, + bool do_codegen, int result_type) { // sort_key_node may need to do projection bool pre_processed_key_ = false; @@ -1707,13 +1986,19 @@ SortArraysToIndicesKernel::SortArraysToIndicesKernel( } } } else { - // Will use Sort Codegen when sorting for several cols - impl_.reset(new Impl(ctx, result_schema, key_projector, projected_types, - key_field_list, sort_directions, nulls_order, NaN_check)); - auto status = impl_->LoadJITFunction(key_field_list, result_schema); - if (!status.ok()) { - std::cout << "LoadJITFunction failed, msg is " << status.message() << std::endl; - throw; + if (do_codegen) { + // Will use Sort Codegen for multiple-key sort + impl_.reset(new Impl(ctx, result_schema, key_projector, projected_types, + key_field_list, sort_directions, nulls_order, NaN_check)); + auto status = impl_->LoadJITFunction(key_field_list, result_schema); + if (!status.ok()) { + std::cout << "LoadJITFunction failed, msg is " << status.message() << std::endl; + throw; + } + } else { + // Will use Sort without Codegen for multiple-key sort + impl_.reset(new SortMultiplekeyKernel(ctx, result_schema, key_projector, + projected_types, key_field_list, sort_directions, nulls_order, NaN_check)); } } kernel_name_ = "SortArraysToIndicesKernel"; diff --git a/cpp/src/tests/arrow_compute_test_sort.cc b/cpp/src/tests/arrow_compute_test_sort.cc index d6d7413fc..f4cd1c825 100644 --- a/cpp/src/tests/arrow_compute_test_sort.cc +++ b/cpp/src/tests/arrow_compute_test_sort.cc @@ -50,8 +50,11 @@ TEST(TestArrowComputeSort, SortTestInPlaceNullsFirstAsc) { "sort_nulls_order", {true_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -61,7 +64,8 @@ TEST(TestArrowComputeSort, SortTestInPlaceNullsFirstAsc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -134,8 +138,11 @@ TEST(TestArrowComputeSort, SortTestInplaceNullsLastAsc) { "sort_nulls_order", {false_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -145,7 +152,8 @@ TEST(TestArrowComputeSort, SortTestInplaceNullsLastAsc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -218,8 +226,11 @@ TEST(TestArrowComputeSort, SortTestInplaceNullsFirstDesc) { "sort_nulls_order", {true_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -229,7 +240,8 @@ TEST(TestArrowComputeSort, SortTestInplaceNullsFirstDesc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -302,8 +314,11 @@ TEST(TestArrowComputeSort, SortTestInplaceNullsLastDesc) { "sort_nulls_order", {false_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -313,7 +328,8 @@ TEST(TestArrowComputeSort, SortTestInplaceNullsLastDesc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -386,8 +402,11 @@ TEST(TestArrowComputeSort, SortTestInplaceAsc) { "sort_nulls_order", {false_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -397,7 +416,8 @@ TEST(TestArrowComputeSort, SortTestInplaceAsc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -470,8 +490,11 @@ TEST(TestArrowComputeSort, SortTestInplaceDesc) { "sort_nulls_order", {false_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -481,7 +504,8 @@ TEST(TestArrowComputeSort, SortTestInplaceDesc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -553,8 +577,11 @@ TEST(TestArrowComputeSort, SortTestOnekeyNullsFirstAsc) { "sort_nulls_order", {TreeExprBuilder::MakeLiteral(true)}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {TreeExprBuilder::MakeLiteral(true)}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {TreeExprBuilder::MakeLiteral(false)}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -564,8 +591,8 @@ TEST(TestArrowComputeSort, SortTestOnekeyNullsFirstAsc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); - + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; std::vector> dummy_result_batches; @@ -644,8 +671,11 @@ TEST(TestArrowComputeSort, SortTestOnekeyNullsLastAsc) { "sort_nulls_order", {TreeExprBuilder::MakeLiteral(false)}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {TreeExprBuilder::MakeLiteral(true)}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {TreeExprBuilder::MakeLiteral(false)}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -655,7 +685,8 @@ TEST(TestArrowComputeSort, SortTestOnekeyNullsLastAsc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -733,8 +764,11 @@ TEST(TestArrowComputeSort, SortTestOnekeyNullsFirstDesc) { "sort_nulls_order", {TreeExprBuilder::MakeLiteral(true)}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {TreeExprBuilder::MakeLiteral(true)}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {TreeExprBuilder::MakeLiteral(false)}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -744,7 +778,8 @@ TEST(TestArrowComputeSort, SortTestOnekeyNullsFirstDesc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -822,8 +857,11 @@ TEST(TestArrowComputeSort, SortTestOnekeyNullsLastDesc) { "sort_nulls_order", {TreeExprBuilder::MakeLiteral(false)}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {TreeExprBuilder::MakeLiteral(true)}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {TreeExprBuilder::MakeLiteral(false)}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -833,7 +871,8 @@ TEST(TestArrowComputeSort, SortTestOnekeyNullsLastDesc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -914,8 +953,11 @@ TEST(TestArrowComputeSort, SortTestOnekeyBooleanDesc) { "sort_nulls_order", {true_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -925,7 +967,8 @@ TEST(TestArrowComputeSort, SortTestOnekeyBooleanDesc) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -1008,8 +1051,11 @@ TEST(TestArrowComputeSort, SortTestOneKeyStr) { "sort_nulls_order", {false_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -1019,7 +1065,8 @@ TEST(TestArrowComputeSort, SortTestOneKeyStr) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; std::vector> dummy_result_batches; @@ -1095,8 +1142,11 @@ TEST(TestArrowComputeSort, SortTestOneKeyWithProjection) { "sort_nulls_order", {false_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -1106,7 +1156,8 @@ TEST(TestArrowComputeSort, SortTestOneKeyWithProjection) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; std::vector> dummy_result_batches; @@ -1182,8 +1233,11 @@ TEST(TestArrowComputeSort, SortTestMultipleKeysNaN) { "sort_nulls_order", {false_literal, true_literal, true_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {true_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -1193,7 +1247,8 @@ TEST(TestArrowComputeSort, SortTestMultipleKeysNaN) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -1306,7 +1361,8 @@ TEST(TestArrowComputeSort, SortTestMultipleKeysWithProjection) { "isnull", {arg_2}, arrow::boolean()); auto n_key_func = TreeExprBuilder::MakeFunction( - "key_function", {coalesce_0, isnull_0, coalesce_1, isnull_1, coalesce_2, isnull_2}, uint32()); + "key_function", + {coalesce_0, isnull_0, coalesce_1, isnull_1, coalesce_2, isnull_2}, uint32()); auto n_key_field = TreeExprBuilder::MakeFunction( "key_field", {arg_0, arg_0, arg_1, arg_1, arg_2, arg_2}, uint32()); auto n_dir = TreeExprBuilder::MakeFunction( @@ -1317,8 +1373,11 @@ TEST(TestArrowComputeSort, SortTestMultipleKeysWithProjection) { true_literal, true_literal}, uint32()); auto NaN_check = TreeExprBuilder::MakeFunction( "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {true_literal}, uint32()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction( - "sortArraysToIndices", {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check}, uint32()); + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); auto n_sort = TreeExprBuilder::MakeFunction( "standalone", {n_sort_to_indices}, uint32()); auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); @@ -1329,7 +1388,8 @@ TEST(TestArrowComputeSort, SortTestMultipleKeysWithProjection) { ///////////////////// Calculation ////////////////// std::shared_ptr sort_expr; arrow::compute::FunctionContext ctx; - ASSERT_NOT_OK(CreateCodeGenerator(ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); std::shared_ptr input_batch; std::vector> input_batch_list; @@ -1401,6 +1461,370 @@ TEST(TestArrowComputeSort, SortTestMultipleKeysWithProjection) { } } +TEST(TestArrowComputeSort, SortTestMultipleKeysWithoutCodegen) { + ////////////////////// prepare expr_vector /////////////////////// + auto f0 = field("f0", float32()); + auto f1 = field("f1", utf8()); + auto f2 = field("f2", uint32()); + auto f3 = field("f3", float64()); + auto arg_0 = TreeExprBuilder::MakeField(f0); + auto arg_1 = TreeExprBuilder::MakeField(f1); + auto arg_2 = TreeExprBuilder::MakeField(f2); + auto true_literal = TreeExprBuilder::MakeLiteral(true); + auto false_literal = TreeExprBuilder::MakeLiteral(false); + auto f_res = field("res", uint32()); + auto indices_type = std::make_shared(16); + auto f_indices = field("indices", indices_type); + + auto n_key_func = TreeExprBuilder::MakeFunction( + "key_function", {arg_0, arg_1, arg_2}, uint32()); + auto n_key_field = TreeExprBuilder::MakeFunction( + "key_field", {arg_0, arg_1, arg_2}, uint32()); + auto n_dir = TreeExprBuilder::MakeFunction( + "sort_directions", {true_literal, false_literal, true_literal}, uint32()); + auto n_nulls_order = TreeExprBuilder::MakeFunction( + "sort_nulls_order", {false_literal, true_literal, true_literal}, uint32()); + auto NaN_check = TreeExprBuilder::MakeFunction( + "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); + auto n_sort_to_indices = TreeExprBuilder::MakeFunction( + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); + auto n_sort = TreeExprBuilder::MakeFunction( + "standalone", {n_sort_to_indices}, uint32()); + auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); + + auto sch = arrow::schema({f0, f1, f2, f3}); + std::vector> ret_types = {f0, f1, f2, f3}; + ///////////////////// Calculation ////////////////// + std::shared_ptr sort_expr; + arrow::compute::FunctionContext ctx; + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + + std::shared_ptr input_batch; + std::vector> input_batch_list; + std::vector> dummy_result_batches; + + std::vector input_data_string = {"[8, 9, 4, 50, 52, 32, 11]", + R"([null, "a", "a", "b", "b","b", "b"])", + "[11, 3, 5, 51, null, 33, 12]", + "[1, 3, 5, 10, null, 13, 2]"}; + MakeInputBatch(input_data_string, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_2 = {"[1, 14, 6, 42, 6, null, 2]", + R"(["a", "a", null, "b", "b", "a", "b"])", + "[2, null, 44, 43, 7, 34, 3]", + "[9, 7, 5, 1, 5, null, 17]"}; + MakeInputBatch(input_data_string_2, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_3 = {"[3, 64, 8, 7, 9, 8, 12]", + R"(["a", "a", "b", "b", "b","b", "b"])", + "[4, 65, 16, 8, 10, 20, 34]", + "[8, 6, 2, 3, 10, 12, 15]"}; + MakeInputBatch(input_data_string_3, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_4 = {"[23, 17, 41, 18, 20, 35, 30]", + R"(["a", "a", "a", "b", "b","b", "b"])", + "[24, 18, 42, 15, 21, 36, 31]", + "[15, 16, 2, 51, null, 33, 12]"}; + MakeInputBatch(input_data_string_4, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_5 = {"[37, null, 22, 13, 8, 59, 21]", + R"(["a", "b", "a", "b", "a","b", "b"])", + "[38, 67, 23, 14, null, 60, 22]", + "[16, 17, 5, 15, 9, null, 19]"}; + MakeInputBatch(input_data_string_5, sch, &input_batch); + input_batch_list.push_back(input_batch); + + ////////////////////////////////// calculation /////////////////////////////////// + std::shared_ptr expected_result; + std::vector expected_result_string = { + "[1, 2, 3, 4, 6, 6, 7, 8, 8, 8, 8, 9, 9, 11, 12, 13, 14, 17, 18, 20, 21, " + "22, 23, 30, 32, 35, 37, 41, 42, 50, 52, 59, 64, null, null]", + R"(["a","b","a","a",null,"b","b",null,"b","b","a","b","a","b","b","b","a","a","b","b","b","a","a","b","b","b","a","a","b","b","b","b","a","b","a"])", + "[2, 3, 4, 5, 44, 7, 8, 11, 16, 20, null, 10, 3, 12, 34, 14, null, 18, 15, 21, 22, " + "23, 24, 31, 33, 36, 38, 42, 43, 51, null, 60, 65, 67, 34]", + "[9, 17, 8, 5, 5, 5, 3, 1, 2, 12, 9, 10, 3, 2, 15, 15, 7, 16, 51, null, 19, 5, " + "15, 12, 13, 33, 16, 2, 1, 10, null, null, 6, 17, null]"}; + + MakeInputBatch(expected_result_string, sch, &expected_result); + + for (auto batch : input_batch_list) { + ASSERT_NOT_OK(sort_expr->evaluate(batch, &dummy_result_batches)); + } + std::shared_ptr> sort_result_iterator; + std::shared_ptr sort_result_iterator_base; + ASSERT_NOT_OK(sort_expr->finish(&sort_result_iterator_base)); + sort_result_iterator = std::dynamic_pointer_cast>( + sort_result_iterator_base); + + std::shared_ptr dummy_result_batch; + std::shared_ptr result_batch; + + if (sort_result_iterator->HasNext()) { + ASSERT_NOT_OK(sort_result_iterator->Next(&result_batch)); + ASSERT_NOT_OK(Equals(*expected_result.get(), *result_batch.get())); + } +} + +TEST(TestArrowComputeSort, SortTestMultipleKeysWithoutCodegenWithProjection) { + ////////////////////// prepare expr_vector /////////////////////// + auto f0 = field("f0", uint32()); + auto f1 = field("f1", utf8()); + auto f2 = field("f2", uint32()); + auto f3 = field("f3", uint32()); + auto arg_0 = TreeExprBuilder::MakeField(f0); + auto arg_1 = TreeExprBuilder::MakeField(f1); + auto arg_2 = TreeExprBuilder::MakeField(f2); + auto true_literal = TreeExprBuilder::MakeLiteral(true); + auto false_literal = TreeExprBuilder::MakeLiteral(false); + auto f_res = field("res", uint32()); + auto f_bool = field("res", arrow::boolean()); + auto indices_type = std::make_shared(16); + auto f_indices = field("indices", indices_type); + + auto uint32_node = TreeExprBuilder::MakeLiteral((uint32_t)0); + auto str_node = TreeExprBuilder::MakeStringLiteral(""); + + auto isnotnull_0 = TreeExprBuilder::MakeFunction( + "isnotnull", {TreeExprBuilder::MakeField(f0)}, arrow::boolean()); + auto coalesce_0 = TreeExprBuilder::MakeIf( + isnotnull_0, TreeExprBuilder::MakeField(f0), uint32_node, uint32()); + auto isnull_0 = TreeExprBuilder::MakeFunction( + "isnull", {arg_0}, arrow::boolean()); + + auto isnotnull_1 = TreeExprBuilder::MakeFunction( + "isnotnull", {TreeExprBuilder::MakeField(f1)}, arrow::boolean()); + auto coalesce_1 = TreeExprBuilder::MakeIf( + isnotnull_1, TreeExprBuilder::MakeField(f1), str_node, utf8()); + auto isnull_1 = TreeExprBuilder::MakeFunction( + "isnull", {arg_1}, arrow::boolean()); + + auto isnotnull_2 = TreeExprBuilder::MakeFunction( + "isnotnull", {TreeExprBuilder::MakeField(f2)}, arrow::boolean()); + auto coalesce_2 = TreeExprBuilder::MakeIf( + isnotnull_2, TreeExprBuilder::MakeField(f2), uint32_node, uint32()); + auto isnull_2 = TreeExprBuilder::MakeFunction( + "isnull", {arg_2}, arrow::boolean()); + + auto n_key_func = TreeExprBuilder::MakeFunction( + "key_function", + {coalesce_0, isnull_0, coalesce_1, isnull_1, coalesce_2, isnull_2}, uint32()); + auto n_key_field = TreeExprBuilder::MakeFunction( + "key_field", {arg_0, arg_0, arg_1, arg_1, arg_2, arg_2}, uint32()); + auto n_dir = TreeExprBuilder::MakeFunction( + "sort_directions", {true_literal, true_literal, false_literal, false_literal, + true_literal, true_literal,}, uint32()); + auto n_nulls_order = TreeExprBuilder::MakeFunction( + "sort_nulls_order", {false_literal, false_literal, true_literal, true_literal, + true_literal, true_literal}, uint32()); + auto NaN_check = TreeExprBuilder::MakeFunction( + "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); + auto n_sort_to_indices = TreeExprBuilder::MakeFunction( + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); + auto n_sort = TreeExprBuilder::MakeFunction( + "standalone", {n_sort_to_indices}, uint32()); + auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); + + auto sch = arrow::schema({f0, f1, f2, f3}); + std::vector> ret_types = {f0, f1, f2, f3}; + auto ret_schema = arrow::schema(ret_types); + ///////////////////// Calculation ////////////////// + std::shared_ptr sort_expr; + arrow::compute::FunctionContext ctx; + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + + std::shared_ptr input_batch; + std::vector> input_batch_list; + std::vector> dummy_result_batches; + + std::vector input_data_string = {"[8, 8, 4, 50, 52, 32, 11]", + R"([null, "b", "a", "b", "b","b", "b"])", + "[11, 10, 5, 51, null, 33, 12]", + "[1, 3, 5, 10, null, 13, 2]"}; + MakeInputBatch(input_data_string, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_2 = {"[1, 14, 8, 42, 6, null, 2]", + R"(["a", "a", null, "b", "b","b", "b"])", + "[2, null, 44, 43, 7, 34, 3]", + "[9, 7, 5, 1, 5, null, 17]"}; + MakeInputBatch(input_data_string_2, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_3 = {"[3, 64, 8, 7, 9, 8, 33]", + R"(["a", "a", "a", "b", "b","b", "b"])", + "[4, 65, 16, 8, 10, 20, 34]", + "[8, 6, 2, 3, 10, 12, 15]"}; + MakeInputBatch(input_data_string_3, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_4 = {"[23, 17, 41, 18, 20, 35, 30]", + R"(["a", "a", "a", "b", "b","b", "b"])", + "[24, 18, 42, 19, 21, 36, 31]", + "[15, 16, 2, 51, null, 33, 12]"}; + MakeInputBatch(input_data_string_4, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_5 = {"[37, null, 22, 13, 8, 59, 21]", + R"(["a", "a", "a", "b", "b","b", "b"])", + "[38, 67, 23, 14, null, 60, 22]", + "[16, 17, 5, 15, 9, null, 19]"}; + MakeInputBatch(input_data_string_5, sch, &input_batch); + input_batch_list.push_back(input_batch); + + ////////////////////////////////// calculation /////////////////////////////////// + std::shared_ptr expected_result; + std::vector expected_result_string = { + "[null, null, 1, 2, 3, 4, 6, 7, 8, 8, 8, 8, 8, 8, 9, 11, 13, 14, 17, 18, 20, 21, " + "22, 23, 30, 32, 33, 35, 37, 41, 42, 50, 52, 59, 64]", + R"(["b","a","a","b","a","a","b","b","b","b","b","a", null, null,"b","b","b","a","a","b","b","b","a","a","b","b","b","b","a","a","b","b","b","b","a"])", + "[34, 67, 2, 3, 4, 5, 7, 8, null, 10, 20, 16, 11, 44, 10, 12, 14, null, 18, 19, 21, 22, 23, " + "24, 31, 33, 34, 36, 38, 42, 43, 51, null, 60, 65]", + "[null, 17, 9, 17, 8, 5, 5, 3, 9, 3, 12, 2, 1, 5, 10, 2, 15, 7, 16, 51, null, 19, 5, " + "15, 12, 13, 15, 33, 16, 2, 1, 10, null, null, 6]"}; + + MakeInputBatch(expected_result_string, ret_schema, &expected_result); + + for (auto batch : input_batch_list) { + ASSERT_NOT_OK(sort_expr->evaluate(batch, &dummy_result_batches)); + } + std::shared_ptr> sort_result_iterator; + std::shared_ptr sort_result_iterator_base; + ASSERT_NOT_OK(sort_expr->finish(&sort_result_iterator_base)); + sort_result_iterator = std::dynamic_pointer_cast>( + sort_result_iterator_base); + + std::shared_ptr dummy_result_batch; + std::shared_ptr result_batch; + + if (sort_result_iterator->HasNext()) { + ASSERT_NOT_OK(sort_result_iterator->Next(&result_batch)); + ASSERT_NOT_OK(Equals(*expected_result.get(), *result_batch.get())); + } +} + +TEST(TestArrowComputeSort, SortTestMultipleKeysNaNWithoutCodegen) { + ////////////////////// prepare expr_vector /////////////////////// + auto f0 = field("f0", float64()); + auto f1 = field("f1", utf8()); + auto f2 = field("f2", float64()); + auto f3 = field("f3", uint32()); + auto arg_0 = TreeExprBuilder::MakeField(f0); + auto arg_1 = TreeExprBuilder::MakeField(f1); + auto arg_2 = TreeExprBuilder::MakeField(f2); + auto true_literal = TreeExprBuilder::MakeLiteral(true); + auto false_literal = TreeExprBuilder::MakeLiteral(false); + auto f_res = field("res", uint32()); + auto indices_type = std::make_shared(16); + auto f_indices = field("indices", indices_type); + + auto n_key_func = TreeExprBuilder::MakeFunction( + "key_function", {arg_0, arg_1, arg_2}, uint32()); + auto n_key_field = TreeExprBuilder::MakeFunction( + "key_field", {arg_0, arg_1, arg_2}, uint32()); + auto n_dir = TreeExprBuilder::MakeFunction( + "sort_directions", {true_literal, false_literal, true_literal}, uint32()); + auto n_nulls_order = TreeExprBuilder::MakeFunction( + "sort_nulls_order", {false_literal, true_literal, true_literal}, uint32()); + auto NaN_check = TreeExprBuilder::MakeFunction( + "NaN_check", {true_literal}, uint32()); + auto do_codegen = TreeExprBuilder::MakeFunction( + "codegen", {false_literal}, uint32()); + auto n_sort_to_indices = TreeExprBuilder::MakeFunction( + "sortArraysToIndices", + {n_key_func, n_key_field, n_dir, n_nulls_order, NaN_check, do_codegen}, uint32()); + auto n_sort = TreeExprBuilder::MakeFunction( + "standalone", {n_sort_to_indices}, uint32()); + auto sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort, f_res); + + auto sch = arrow::schema({f0, f1, f2, f3}); + std::vector> ret_types = {f0, f1, f2, f3}; + ///////////////////// Calculation ////////////////// + std::shared_ptr sort_expr; + arrow::compute::FunctionContext ctx; + ASSERT_NOT_OK(CreateCodeGenerator( + ctx.memory_pool(), sch, {sortArrays_expr}, ret_types, &sort_expr, true)); + + std::shared_ptr input_batch; + std::vector> input_batch_list; + std::vector> dummy_result_batches; + + std::vector input_data_string = {"[8, NaN, 4, 50, 52, 32, 11]", + R"([null, "a", "a", "b", "b","b", "b"])", + "[11, NaN, 5, 51, null, 33, 12]", + "[1, 3, 5, 10, null, 13, 2]"}; + MakeInputBatch(input_data_string, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_2 = {"[1, 14, NaN, 42, 6, null, 2]", + R"(["a", "a", null, "b", "b", "a", "b"])", + "[2, null, 44, 43, 7, 34, 3]", + "[9, 7, 5, 1, 5, null, 17]"}; + MakeInputBatch(input_data_string_2, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_3 = {"[3, 64, 8, 7, 9, 8, NaN]", + R"(["a", "a", "b", "b", "b","b", "b"])", + "[4, 65, 16, 8, 10, 20, 34]", + "[8, 6, 2, 3, 10, 12, 15]"}; + MakeInputBatch(input_data_string_3, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_4 = {"[23, 17, 41, 18, 20, 35, 30]", + R"(["a", "a", "a", "b", "b","b", "b"])", + "[24, 18, 42, NaN, 21, 36, 31]", + "[15, 16, 2, 51, null, 33, 12]"}; + MakeInputBatch(input_data_string_4, sch, &input_batch); + input_batch_list.push_back(input_batch); + + std::vector input_data_string_5 = {"[37, null, 22, 13, 8, 59, 21]", + R"(["a", "b", "a", "b", "b","b", "b"])", + "[38, 67, 23, 14, null, 60, 22]", + "[16, 17, 5, 15, 9, null, 19]"}; + MakeInputBatch(input_data_string_5, sch, &input_batch); + input_batch_list.push_back(input_batch); + + ////////////////////////////////// calculation /////////////////////////////////// + std::shared_ptr expected_result; + std::vector expected_result_string = { + "[1, 2, 3, 4, 6, 7, 8, 8, 8, 8, 9, 11, 13, 14, 17, 18, 20, 21, " + "22, 23, 30, 32, 35, 37, 41, 42, 50, 52, 59, 64, NaN, NaN, NaN, null, null]", + R"(["a","b","a","a","b","b", null,"b","b","b","b","b","b","a","a","b","b","b","a","a","b","b","b","a","a","b","b","b","b","a",null,"b","a","b","a"])", + "[2, 3, 4, 5, 7, 8, 11, null, 16, 20, 10, 12, 14, null, 18, NaN, 21, 22, 23, " + "24, 31, 33, 36, 38, 42, 43, 51, null, 60, 65, 44, 34, NaN, 67, 34]", + "[9, 17, 8, 5, 5, 3, 1, 9, 2, 12, 10, 2, 15, 7, 16, 51, null, 19, 5, " + "15, 12, 13, 33, 16, 2, 1, 10, null, null, 6, 5, 15, 3, 17, null]"}; + + MakeInputBatch(expected_result_string, sch, &expected_result); + + for (auto batch : input_batch_list) { + ASSERT_NOT_OK(sort_expr->evaluate(batch, &dummy_result_batches)); + } + std::shared_ptr> sort_result_iterator; + std::shared_ptr sort_result_iterator_base; + ASSERT_NOT_OK(sort_expr->finish(&sort_result_iterator_base)); + sort_result_iterator = std::dynamic_pointer_cast>( + sort_result_iterator_base); + + std::shared_ptr dummy_result_batch; + std::shared_ptr result_batch; + + if (sort_result_iterator->HasNext()) { + ASSERT_NOT_OK(sort_result_iterator->Next(&result_batch)); + ASSERT_NOT_OK(Equals(*expected_result.get(), *result_batch.get())); + } +} } // namespace codegen } // namespace sparkcolumnarplugin diff --git a/cpp/src/third_party/function.h b/cpp/src/third_party/function.h new file mode 100644 index 000000000..291d46ab3 --- /dev/null +++ b/cpp/src/third_party/function.h @@ -0,0 +1,630 @@ +/* +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to + */ +// despite that it would be nice if you give credit to Malte Skarupke + + +#pragma once +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#define FUNC_NOEXCEPT +#define FUNC_TEMPLATE_NOEXCEPT(FUNCTOR, ALLOCATOR) +#define FUNC_CONSTEXPR const +#else +#define FUNC_NOEXCEPT noexcept +#define FUNC_TEMPLATE_NOEXCEPT(FUNCTOR, ALLOCATOR) noexcept(detail::is_inplace_allocated::value) +#define FUNC_CONSTEXPR constexpr +#endif +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#define FUNC_MOVE(value) static_cast::type &&>(value) +#define FUNC_FORWARD(type, value) static_cast(value) + +namespace func +{ +#ifndef FUNC_NO_EXCEPTIONS + struct bad_function_call : std::exception + { + const char * what() const FUNC_NOEXCEPT override + { + return "Bad function call"; + } + }; +#endif + +template +struct force_function_heap_allocation + : std::false_type +{ +}; + +template +class function; + +namespace detail +{ + struct manager_storage_type; + struct function_manager; + struct functor_padding + { + protected: + size_t padding_first; + size_t padding_second; + }; + + struct empty_struct + { + }; + +# ifndef FUNC_NO_EXCEPTIONS + template + Result empty_call(const functor_padding &, Arguments...) + { + throw bad_function_call(); + } +# endif + + template + struct is_inplace_allocated + { + static const bool value + // so that it fits + = sizeof(T) <= sizeof(functor_padding) + // so that it will be aligned + && std::alignment_of::value % std::alignment_of::value == 0 + // so that we can offer noexcept move + && std::is_nothrow_move_constructible::value + // so that the user can override it + && !force_function_heap_allocation::value; + }; + + template + T to_functor(T && func) + { + return FUNC_FORWARD(T, func); + } + template + auto to_functor(Result (Class::*func)(Arguments...)) -> decltype(std::mem_fn(func)) + { + return std::mem_fn(func); + } + template + auto to_functor(Result (Class::*func)(Arguments...) const) -> decltype(std::mem_fn(func)) + { + return std::mem_fn(func); + } + + template + struct functor_type + { + typedef decltype(to_functor(std::declval())) type; + }; + + template + bool is_null(const T &) + { + return false; + } + template + bool is_null(Result (* const & function_pointer)(Arguments...)) + { + return function_pointer == nullptr; + } + template + bool is_null(Result (Class::* const & function_pointer)(Arguments...)) + { + return function_pointer == nullptr; + } + template + bool is_null(Result (Class::* const & function_pointer)(Arguments...) const) + { + return function_pointer == nullptr; + } + + template + struct is_valid_function_argument + { + static const bool value = false; + }; + + template + struct is_valid_function_argument, Result (Arguments...)> + { + static const bool value = false; + }; + + template + struct is_valid_function_argument + { +# ifdef _MSC_VER + // as of january 2013 visual studio doesn't support the SFINAE below + static const bool value = true; +# else + template + static decltype(to_functor(std::declval())(std::declval()...)) check(U *); + template + static empty_struct check(...); + + static const bool value = std::is_convertible(nullptr)), Result>::value; +# endif + }; + + typedef const function_manager * manager_type; + + struct manager_storage_type + { + template + Allocator & get_allocator() FUNC_NOEXCEPT + { + return reinterpret_cast(manager); + } + template + const Allocator & get_allocator() const FUNC_NOEXCEPT + { + return reinterpret_cast(manager); + } + + functor_padding functor; + manager_type manager; + }; + + template + struct function_manager_inplace_specialization + { + template + static Result call(const functor_padding & storage, Arguments... arguments) + { + // do not call get_functor_ref because I want this function to be fast + // in debug when nothing gets inlined + return const_cast(reinterpret_cast(storage))(FUNC_FORWARD(Arguments, arguments)...); + } + + static void store_functor(manager_storage_type & storage, T to_store) + { + new (&get_functor_ref(storage)) T(FUNC_FORWARD(T, to_store)); + } + static void move_functor(manager_storage_type & lhs, manager_storage_type && rhs) FUNC_NOEXCEPT + { + new (&get_functor_ref(lhs)) T(FUNC_MOVE(get_functor_ref(rhs))); + } + static void destroy_functor(Allocator &, manager_storage_type & storage) FUNC_NOEXCEPT + { + get_functor_ref(storage).~T(); + } + static T & get_functor_ref(const manager_storage_type & storage) FUNC_NOEXCEPT + { + return const_cast(reinterpret_cast(storage.functor)); + } + }; + template + struct function_manager_inplace_specialization::value>::type> + { + template + static Result call(const functor_padding & storage, Arguments... arguments) + { + // do not call get_functor_ptr_ref because I want this function to be fast + // in debug when nothing gets inlined + return (*reinterpret_cast::pointer &>(storage))(FUNC_FORWARD(Arguments, arguments)...); + } + + static void store_functor(manager_storage_type & self, T to_store) + { + Allocator & allocator = self.get_allocator();; + static_assert(sizeof(typename std::allocator_traits::pointer) <= sizeof(self.functor), "The allocator's pointer type is too big"); + typename std::allocator_traits::pointer * ptr = new (&get_functor_ptr_ref(self)) typename std::allocator_traits::pointer(std::allocator_traits::allocate(allocator, 1)); + std::allocator_traits::construct(allocator, *ptr, FUNC_FORWARD(T, to_store)); + } + static void move_functor(manager_storage_type & lhs, manager_storage_type && rhs) FUNC_NOEXCEPT + { + static_assert(std::is_nothrow_move_constructible::pointer>::value, "we can't offer a noexcept swap if the pointer type is not nothrow move constructible"); + new (&get_functor_ptr_ref(lhs)) typename std::allocator_traits::pointer(FUNC_MOVE(get_functor_ptr_ref(rhs))); + // this next assignment makes the destroy function easier + get_functor_ptr_ref(rhs) = nullptr; + } + static void destroy_functor(Allocator & allocator, manager_storage_type & storage) FUNC_NOEXCEPT + { + typename std::allocator_traits::pointer & pointer = get_functor_ptr_ref(storage); + if (!pointer) return; + std::allocator_traits::destroy(allocator, pointer); + std::allocator_traits::deallocate(allocator, pointer, 1); + } + static T & get_functor_ref(const manager_storage_type & storage) FUNC_NOEXCEPT + { + return *get_functor_ptr_ref(storage); + } + static typename std::allocator_traits::pointer & get_functor_ptr_ref(manager_storage_type & storage) FUNC_NOEXCEPT + { + return reinterpret_cast::pointer &>(storage.functor); + } + static const typename std::allocator_traits::pointer & get_functor_ptr_ref(const manager_storage_type & storage) FUNC_NOEXCEPT + { + return reinterpret_cast::pointer &>(storage.functor); + } + }; + + template + static const function_manager & get_default_manager(); + + template + static void create_manager(manager_storage_type & storage, Allocator && allocator) + { + new (&storage.get_allocator()) Allocator(FUNC_MOVE(allocator)); + storage.manager = &get_default_manager(); + } + + // this struct acts as a vtable. it is an optimization to prevent + // code-bloat from rtti. see the documentation of boost::function + struct function_manager + { + template + inline static FUNC_CONSTEXPR function_manager create_default_manager() + { +# ifdef _MSC_VER + function_manager result = +# else + return function_manager +# endif + { + &templated_call_move_and_destroy, + &templated_call_copy, + &templated_call_copy_functor_only, + &templated_call_destroy, +# ifndef FUNC_NO_RTTI + &templated_call_type_id, + &templated_call_target +# endif + }; +# ifdef _MSC_VER + return result; +# endif + } + + void (* const call_move_and_destroy)(manager_storage_type & lhs, manager_storage_type && rhs); + void (* const call_copy)(manager_storage_type & lhs, const manager_storage_type & rhs); + void (* const call_copy_functor_only)(manager_storage_type & lhs, const manager_storage_type & rhs); + void (* const call_destroy)(manager_storage_type & manager); +# ifndef FUNC_NO_RTTI + const std::type_info & (* const call_type_id)(); + void * (* const call_target)(const manager_storage_type & manager, const std::type_info & type); +# endif + + template + static void templated_call_move_and_destroy(manager_storage_type & lhs, manager_storage_type && rhs) + { + typedef function_manager_inplace_specialization specialization; + specialization::move_functor(lhs, FUNC_MOVE(rhs)); + specialization::destroy_functor(rhs.get_allocator(), rhs); + create_manager(lhs, FUNC_MOVE(rhs.get_allocator())); + rhs.get_allocator().~Allocator(); + } + template + static void templated_call_copy(manager_storage_type & lhs, const manager_storage_type & rhs) + { + typedef function_manager_inplace_specialization specialization; + create_manager(lhs, Allocator(rhs.get_allocator())); + specialization::store_functor(lhs, specialization::get_functor_ref(rhs)); + } + template + static void templated_call_destroy(manager_storage_type & self) + { + typedef function_manager_inplace_specialization specialization; + specialization::destroy_functor(self.get_allocator(), self); + self.get_allocator().~Allocator(); + } + template + static void templated_call_copy_functor_only(manager_storage_type & lhs, const manager_storage_type & rhs) + { + typedef function_manager_inplace_specialization specialization; + specialization::store_functor(lhs, specialization::get_functor_ref(rhs)); + } +# ifndef FUNC_NO_RTTI + template + static const std::type_info & templated_call_type_id() + { + return typeid(T); + } + template + static void * templated_call_target(const manager_storage_type & self, const std::type_info & type) + { + typedef function_manager_inplace_specialization specialization; + if (type == typeid(T)) + return &specialization::get_functor_ref(self); + else + return nullptr; + } +# endif + }; + template + inline static const function_manager & get_default_manager() + { + static FUNC_CONSTEXPR function_manager default_manager = function_manager::create_default_manager(); + return default_manager; + } + + template + struct typedeffer + { + typedef Result result_type; + }; + template + struct typedeffer + { + typedef Result result_type; + typedef Argument argument_type; + }; + template + struct typedeffer + { + typedef Result result_type; + typedef First_Argument first_argument_type; + typedef Second_Argument second_argument_type; + }; +} + +template +class function + : public detail::typedeffer +{ +public: + function() FUNC_NOEXCEPT + { + initialize_empty(); + } + function(std::nullptr_t) FUNC_NOEXCEPT + { + initialize_empty(); + } + function(function && other) FUNC_NOEXCEPT + { + initialize_empty(); + swap(other); + } + function(const function & other) + : call(other.call) + { + other.manager_storage.manager->call_copy(manager_storage, other.manager_storage); + } + template + function(T functor, + typename std::enable_if::value, detail::empty_struct>::type = detail::empty_struct()) FUNC_TEMPLATE_NOEXCEPT(T, std::allocator::type>) + { + if (detail::is_null(functor)) + { + initialize_empty(); + } + else + { + typedef typename detail::functor_type::type functor_type; + initialize(detail::to_functor(FUNC_FORWARD(T, functor)), std::allocator()); + } + } + template + function(std::allocator_arg_t, const Allocator &) + { + // ignore the allocator because I don't allocate + initialize_empty(); + } + template + function(std::allocator_arg_t, const Allocator &, std::nullptr_t) + { + // ignore the allocator because I don't allocate + initialize_empty(); + } + template + function(std::allocator_arg_t, const Allocator & allocator, T functor, + typename std::enable_if::value, detail::empty_struct>::type = detail::empty_struct()) + FUNC_TEMPLATE_NOEXCEPT(T, Allocator) + { + if (detail::is_null(functor)) + { + initialize_empty(); + } + else + { + initialize(detail::to_functor(FUNC_FORWARD(T, functor)), Allocator(allocator)); + } + } + template + function(std::allocator_arg_t, const Allocator & allocator, const function & other) + : call(other.call) + { + typedef typename std::allocator_traits::template rebind_alloc MyAllocator; + + // first try to see if the allocator matches the target type + detail::manager_type manager_for_allocator = &detail::get_default_manager::value_type, Allocator>(); + if (other.manager_storage.manager == manager_for_allocator) + { + detail::create_manager::value_type, Allocator>(manager_storage, Allocator(allocator)); + manager_for_allocator->call_copy_functor_only(manager_storage, other.manager_storage); + } + // if it does not, try to see if the target contains my type. this + // breaks the recursion of the last case. otherwise repeated copies + // would allocate more and more memory + else + { + detail::manager_type manager_for_function = &detail::get_default_manager(); + if (other.manager_storage.manager == manager_for_function) + { + detail::create_manager(manager_storage, MyAllocator(allocator)); + manager_for_function->call_copy_functor_only(manager_storage, other.manager_storage); + } + else + { + // else store the other function as my target + initialize(other, MyAllocator(allocator)); + } + } + } + template + function(std::allocator_arg_t, const Allocator &, function && other) FUNC_NOEXCEPT + { + // ignore the allocator because I don't allocate + initialize_empty(); + swap(other); + } + + function & operator=(function other) FUNC_NOEXCEPT + { + swap(other); + return *this; + } + ~function() FUNC_NOEXCEPT + { + manager_storage.manager->call_destroy(manager_storage); + } + + Result operator()(Arguments... arguments) const + { + return call(manager_storage.functor, FUNC_FORWARD(Arguments, arguments)...); + } + + template + void assign(T && functor, const Allocator & allocator) FUNC_TEMPLATE_NOEXCEPT(T, Allocator) + { + function(std::allocator_arg, allocator, functor).swap(*this); + } + + void swap(function & other) FUNC_NOEXCEPT + { + detail::manager_storage_type temp_storage; + other.manager_storage.manager->call_move_and_destroy(temp_storage, FUNC_MOVE(other.manager_storage)); + manager_storage.manager->call_move_and_destroy(other.manager_storage, FUNC_MOVE(manager_storage)); + temp_storage.manager->call_move_and_destroy(manager_storage, FUNC_MOVE(temp_storage)); + + std::swap(call, other.call); + } + + +# ifndef FUNC_NO_RTTI + const std::type_info & target_type() const FUNC_NOEXCEPT + { + return manager_storage.manager->call_type_id(); + } + template + T * target() FUNC_NOEXCEPT + { + return static_cast(manager_storage.manager->call_target(manager_storage, typeid(T))); + } + template + const T * target() const FUNC_NOEXCEPT + { + return static_cast(manager_storage.manager->call_target(manager_storage, typeid(T))); + } +# endif + + operator bool() const FUNC_NOEXCEPT + { + +# ifdef FUNC_NO_EXCEPTIONS + return call != nullptr; +# else + return call != &detail::empty_call; +# endif + } + +private: + detail::manager_storage_type manager_storage; + Result (*call)(const detail::functor_padding &, Arguments...); + + template + void initialize(T functor, Allocator && allocator) + { + call = &detail::function_manager_inplace_specialization::template call; + detail::create_manager(manager_storage, FUNC_FORWARD(Allocator, allocator)); + detail::function_manager_inplace_specialization::store_functor(manager_storage, FUNC_FORWARD(T, functor)); + } + + typedef Result(*Empty_Function_Type)(Arguments...); + void initialize_empty() FUNC_NOEXCEPT + { + typedef std::allocator Allocator; + static_assert(detail::is_inplace_allocated::value, "The empty function should benefit from small functor optimization"); + + detail::create_manager(manager_storage, Allocator()); + detail::function_manager_inplace_specialization::store_functor(manager_storage, nullptr); +# ifdef FUNC_NO_EXCEPTIONS + call = nullptr; +# else + call = &detail::empty_call; +# endif + } +}; + +template +bool operator==(std::nullptr_t, const function & rhs) FUNC_NOEXCEPT +{ + return !rhs; +} +template +bool operator==(const function & lhs, std::nullptr_t) FUNC_NOEXCEPT +{ + return !lhs; +} +template +bool operator!=(std::nullptr_t, const function & rhs) FUNC_NOEXCEPT +{ + return rhs; +} +template +bool operator!=(const function & lhs, std::nullptr_t) FUNC_NOEXCEPT +{ + return lhs; +} + +template +void swap(function & lhs, function & rhs) +{ + lhs.swap(rhs); +} + +} // end namespace func + +namespace std +{ +template +struct uses_allocator, Allocator> + : std::true_type +{ +}; +} + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif +#undef FUNC_NOEXCEPT +#undef FUNC_TEMPLATE_NOEXCEPT +#undef FUNC_FORWARD +#undef FUNC_MOVE +#undef FUNC_CONSTEXPR