diff --git a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala index e6a8bf2c8f20..a00bcae1ce70 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala @@ -159,7 +159,13 @@ object VeloxIntermediateData { * row_constructor_with_null. */ def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc match { - case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] => "row_constructor" + case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] => + "row_constructor" + // For agg function min_by/max_by, it needs to keep rows with null value but non-null + // comparison, such as (null, 5). So we set the struct to null when all of the arguments + // are null + case _: MaxMinBy => + "row_constructor_with_all_null" case _ => "row_constructor_with_null" } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala index 394c4e01651e..70fff52b84d6 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala @@ -27,6 +27,8 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu override protected val resourcePath: String = "/tpch-data-parquet-velox" override protected val fileFormat: String = "parquet" + import testImplicits._ + override def beforeAll(): Unit = { super.beforeAll() createTPCHNotNullTables() @@ -188,6 +190,22 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu } } + test("min_by/max_by") { + withTempPath { + path => + Seq((5: Integer, 6: Integer), (null: 
Integer, 11: Integer), (null: Integer, 5: Integer)) + .toDF("a", "b") + .write + .parquet(path.getCanonicalPath) + spark.read + .parquet(path.getCanonicalPath) + .createOrReplaceTempView("test") + runQueryAndCompare("select min_by(a, b), max_by(a, b) from test") { + checkGlutenOperatorMatch[HashAggregateExecTransformer] + } + } + } + test("groupby") { val df = runQueryAndCompare( "select l_orderkey, sum(l_partkey) as sum from lineitem " + diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index 2d2e820f1d03..c77fa47e5bff 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -16,6 +16,7 @@ */ #include "operators/functions/RegistrationAllFunctions.h" #include "operators/functions/Arithmetic.h" +#include "operators/functions/RowConstructorWithAllNull.h" #include "operators/functions/RowConstructorWithNull.h" #include "operators/functions/RowFunctionWithNull.h" @@ -47,11 +48,19 @@ void registerFunctionOverwrite() { velox::exec::registerVectorFunction( "row_constructor_with_null", std::vector>{}, - std::make_unique(), - RowFunctionWithNull::metadata()); + std::make_unique>(), + RowFunctionWithNull::metadata()); velox::exec::registerFunctionCallToSpecialForm( RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull, std::make_unique()); + velox::exec::registerVectorFunction( + "row_constructor_with_all_null", + std::vector>{}, + std::make_unique>(), + RowFunctionWithNull::metadata()); + velox::exec::registerFunctionCallToSpecialForm( + RowConstructorWithAllNullCallToSpecialForm::kRowConstructorWithAllNull, + std::make_unique()); velox::functions::sparksql::registerBitwiseFunctions("spark_"); } } // namespace diff --git a/cpp/velox/operators/functions/RowConstructorWithAllNull.h b/cpp/velox/operators/functions/RowConstructorWithAllNull.h new file mode 100644 index 000000000000..dfc79e1a977b --- 
/dev/null +++ b/cpp/velox/operators/functions/RowConstructorWithAllNull.h @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "RowConstructorWithNull.h" + +namespace gluten { +class RowConstructorWithAllNullCallToSpecialForm : public RowConstructorWithNullCallToSpecialForm { + public: + static constexpr const char* kRowConstructorWithAllNull = "row_constructor_with_all_null"; + + protected: + facebook::velox::exec::ExprPtr constructSpecialForm( + const std::string& name, + const facebook::velox::TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage, + const facebook::velox::core::QueryConfig& config) { + return constructSpecialForm(kRowConstructorWithAllNull, type, std::move(compiledChildren), trackCpuUsage, config); + } +}; +} // namespace gluten diff --git a/cpp/velox/operators/functions/RowFunctionWithNull.h b/cpp/velox/operators/functions/RowFunctionWithNull.h index 9ed6bc27792a..4131fb472ddd 100644 --- a/cpp/velox/operators/functions/RowFunctionWithNull.h +++ b/cpp/velox/operators/functions/RowFunctionWithNull.h @@ -23,8 +23,10 @@ namespace gluten { /** - * A customized RowFunction to set struct as null when one of its argument is null. 
+ * @tparam allNull If true, set struct as null when all of its arguments are null, else will + * set it null when one of its arguments is null. */ +template class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction { public: void apply( @@ -42,13 +44,26 @@ class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction { rows.applyToSelected([&](facebook::velox::vector_size_t i) { facebook::velox::bits::clearNull(nullsPtr, i); if (!facebook::velox::bits::isBitNull(nullsPtr, i)) { + int argsNullCnt = 0; for (size_t c = 0; c < argsCopy.size(); c++) { auto arg = argsCopy[c].get(); if (arg->mayHaveNulls() && arg->isNullAt(i)) { - // If any argument of the struct is null, set the struct as null. + // For row_constructor_with_null, if any argument of the struct is null, + // set the struct as null. + if constexpr (!allNull) { + facebook::velox::bits::setNull(nullsPtr, i, true); + cntNull++; + break; + } else { + argsNullCnt++; + } + } + } + // For row_constructor_with_all_null, set the struct to be null when all arguments are null + if constexpr (allNull) { + if (argsNullCnt == argsCopy.size()) { facebook::velox::bits::setNull(nullsPtr, i, true); cntNull++; - break; } } }