Skip to content

Commit

Permalink
[GLUTEN-4652] Fix min_by/max_by result mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
yma11 committed Apr 28, 2024
1 parent dba95de commit 789591b
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ object VeloxIntermediateData {
* row_constructor_with_null.
*/
def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc match {
  // Average and Sum producing a decimal result use the plain row constructor.
  case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] =>
    "row_constructor"
  // min_by/max_by use the variant that marks the struct as null only when all of
  // its arguments are null.
  case _: MaxMinBy =>
    "row_constructor_with_all_null"
  // Every other aggregate uses the variant that marks the struct as null as soon as
  // any one of its arguments is null.
  case _ => "row_constructor_with_null"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
override protected val resourcePath: String = "/tpch-data-parquet-velox"
override protected val fileFormat: String = "parquet"

import testImplicits._

override def beforeAll(): Unit = {
super.beforeAll()
createTPCHNotNullTables()
Expand Down Expand Up @@ -188,6 +190,22 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}
}

test("min_by/max_by") {
  withTempPath {
    dir =>
      // Resolve the temporary location once and reuse it for the write and the read.
      val location = dir.getCanonicalPath
      // Column "a" is null in two of the three rows, while column "b" is never null.
      val input =
        Seq((5: Integer, 6: Integer), (null: Integer, 11: Integer), (null: Integer, 5: Integer))
          .toDF("a", "b")
      input.write.parquet(location)
      spark.read.parquet(location).createOrReplaceTempView("test")
      runQueryAndCompare("select min_by(a, b), max_by(a, b) from test") {
        checkGlutenOperatorMatch[HashAggregateExecTransformer]
      }
  }
}

test("groupby") {
val df = runQueryAndCompare(
"select l_orderkey, sum(l_partkey) as sum from lineitem " +
Expand Down
1 change: 1 addition & 0 deletions cpp/velox/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ set(VELOX_SRCS
memory/VeloxMemoryManager.cc
operators/functions/RegistrationAllFunctions.cc
operators/functions/RowConstructorWithNull.cc
operators/functions/RowConstructorWithAllNull.cc
operators/functions/SparkTokenizer.cc
operators/serializer/VeloxColumnarToRowConverter.cc
operators/serializer/VeloxColumnarBatchSerializer.cc
Expand Down
13 changes: 11 additions & 2 deletions cpp/velox/operators/functions/RegistrationAllFunctions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
#include "operators/functions/RegistrationAllFunctions.h"
#include "operators/functions/Arithmetic.h"
#include "operators/functions/RowConstructorWithAllNull.h"
#include "operators/functions/RowConstructorWithNull.h"
#include "operators/functions/RowFunctionWithNull.h"

Expand Down Expand Up @@ -47,11 +48,19 @@ void registerFunctionOverwrite() {
velox::exec::registerVectorFunction(
"row_constructor_with_null",
std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
std::make_unique<RowFunctionWithNull>(),
RowFunctionWithNull::metadata());
std::make_unique<RowFunctionWithNull</*allNull=*/false>>(),
RowFunctionWithNull</*allNull=*/false>::metadata());
velox::exec::registerFunctionCallToSpecialForm(
RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull,
std::make_unique<RowConstructorWithNullCallToSpecialForm>());
velox::exec::registerVectorFunction(
"row_constructor_with_all_null",
std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
std::make_unique<RowFunctionWithNull</*allNull=*/true>>(),
RowFunctionWithNull</*allNull=*/true>::metadata());
velox::exec::registerFunctionCallToSpecialForm(
RowConstructorWithAllNullCallToSpecialForm::kRowConstructorWithAllNull,
std::make_unique<RowConstructorWithAllNullCallToSpecialForm>());
velox::functions::sparksql::registerBitwiseFunctions("spark_");
}
} // namespace
Expand Down
63 changes: 63 additions & 0 deletions cpp/velox/operators/functions/RowConstructorWithAllNull.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "RowConstructorWithAllNull.h"
#include "velox/expression/VectorFunction.h"

namespace gluten {
// Builds the ROW type for a row_constructor_with_all_null call: one field per
// argument, keeping the argument's type and naming the fields c1, c2, ... in
// argument order. An empty argument list yields a ROW with no fields.
facebook::velox::TypePtr RowConstructorWithAllNullCallToSpecialForm::resolveType(
    const std::vector<facebook::velox::TypePtr>& argTypes) {
  const size_t numInput = argTypes.size();
  std::vector<std::string> names;
  std::vector<facebook::velox::TypePtr> types;
  names.reserve(numInput);
  types.reserve(numInput);
  // Use an unsigned index so the counter has the same signedness as the size it
  // is compared against (a plain `auto i = 0` would deduce int).
  for (size_t i = 0; i < numInput; ++i) {
    types.push_back(argTypes[i]);
    // Field names are 1-based: c1, c2, ...
    names.push_back(fmt::format("c{}", i + 1));
  }
  return facebook::velox::ROW(std::move(names), std::move(types));
}

// Builds the expression for a call to `name`: finds the vector function
// registered under that name, instantiates it through its factory, and wraps it
// together with the compiled children in an Expr. Fails via VELOX_FAIL when no
// function is registered under `name`.
facebook::velox::exec::ExprPtr RowConstructorWithAllNullCallToSpecialForm::constructSpecialForm(
    const std::string& name,
    const facebook::velox::TypePtr& type,
    std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
    bool trackCpuUsage,
    const facebook::velox::core::QueryConfig& config) {
  using VectorFunctionPtr = std::shared_ptr<facebook::velox::exec::VectorFunction>;
  using Metadata = facebook::velox::exec::VectorFunctionMetadata;

  // The lookup and the factory call both happen inside withRLock on the registry.
  auto resolved = facebook::velox::exec::vectorFunctionFactories().withRLock(
      [&](auto& registry) -> std::pair<VectorFunctionPtr, Metadata> {
        auto entry = registry.find(name);
        if (entry == registry.end()) {
          VELOX_FAIL("Function {} is not registered.", name);
        }
        return {entry->second.factory(name, {}, config), entry->second.metadata};
      });

  return std::make_shared<facebook::velox::exec::Expr>(
      type, std::move(compiledChildren), resolved.first, resolved.second, name, trackCpuUsage);
}

// Public override: delegates to the name-taking overload, passing
// kRowConstructorWithAllNull as the name of the function to look up.
facebook::velox::exec::ExprPtr RowConstructorWithAllNullCallToSpecialForm::constructSpecialForm(
const facebook::velox::TypePtr& type,
std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
bool trackCpuUsage,
const facebook::velox::core::QueryConfig& config) {
return constructSpecialForm(kRowConstructorWithAllNull, type, std::move(compiledChildren), trackCpuUsage, config);
}
} // namespace gluten
44 changes: 44 additions & 0 deletions cpp/velox/operators/functions/RowConstructorWithAllNull.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "velox/expression/FunctionCallToSpecialForm.h"
#include "velox/expression/SpecialForm.h"

namespace gluten {
// FunctionCallToSpecialForm implementation for calls to the function named
// "row_constructor_with_all_null" (see kRowConstructorWithAllNull).
class RowConstructorWithAllNullCallToSpecialForm : public facebook::velox::exec::FunctionCallToSpecialForm {
public:
// Returns a ROW type with one field per entry of argTypes, in order, with the
// fields named c1, c2, ...
facebook::velox::TypePtr resolveType(const std::vector<facebook::velox::TypePtr>& argTypes) override;

// Builds the expression for a call, using kRowConstructorWithAllNull as the
// name of the vector function to look up.
facebook::velox::exec::ExprPtr constructSpecialForm(
const facebook::velox::TypePtr& type,
std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
bool trackCpuUsage,
const facebook::velox::core::QueryConfig& config) override;

// Name under which the function is looked up and the special form is registered.
static constexpr const char* kRowConstructorWithAllNull = "row_constructor_with_all_null";

protected:
// Same as the public overload, but with the vector function's name supplied
// explicitly by the caller.
facebook::velox::exec::ExprPtr constructSpecialForm(
const std::string& name,
const facebook::velox::TypePtr& type,
std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
bool trackCpuUsage,
const facebook::velox::core::QueryConfig& config);
};
} // namespace gluten
23 changes: 18 additions & 5 deletions cpp/velox/operators/functions/RowFunctionWithNull.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
namespace gluten {

/**
* A customized RowFunction to set struct as null when one of its argument is null.
* @tparam allNull If true, set the struct as null only when all of its arguments are null;
* otherwise set it as null when any one of its arguments is null.
*/
template <bool allNull>
class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction {
public:
void apply(
Expand All @@ -42,15 +44,26 @@ class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction {
rows.applyToSelected([&](facebook::velox::vector_size_t i) {
facebook::velox::bits::clearNull(nullsPtr, i);
if (!facebook::velox::bits::isBitNull(nullsPtr, i)) {
int argsNullCnt = 0;
for (size_t c = 0; c < argsCopy.size(); c++) {
auto arg = argsCopy[c].get();
if (arg->mayHaveNulls() && arg->isNullAt(i)) {
// If any argument of the struct is null, set the struct as null.
facebook::velox::bits::setNull(nullsPtr, i, true);
cntNull++;
break;
// For row_constructor_with_null, if any argument of the struct is null,
// set the struct as null.
if (!allNull) {
facebook::velox::bits::setNull(nullsPtr, i, true);
cntNull++;
break;
} else {
argsNullCnt++;
}
}
}
// For row_constructor_with_all_null, set the struct to be null when all arguments are null.
if (allNull && argsNullCnt == argsCopy.size()) {
facebook::velox::bits::setNull(nullsPtr, i, true);
cntNull++;
}
}
});

Expand Down

0 comments on commit 789591b

Please sign in to comment.