-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Spark at_least_n_non_nulls function (#10508)
Summary: This function returns true if at least 'n' of the arguments that follow 'n' are non-null and, for floating-point arguments, non-NaN.
Spark's implementation: https://github.com/apache/spark/blob/110b558570176d5d5ee5ba85bb071fd66a94a7a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala#L428
Pull Request resolved: #10508
Reviewed By: xiaoxmeng
Differential Revision: D63261521
Pulled By: Yuhta
fbshipit-source-id: a38c4125defe2d1516a9404fba6b46d5c4f37e14
- Loading branch information
1 parent
4e45bc5
commit bde87ce
Showing
8 changed files
with
364 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <limits> | ||
|
||
namespace facebook::velox::test { | ||
|
||
/// Floating-point edge-case values shared by tests: quiet NaN and the largest
/// finite value, for both double ('D' suffix) and float ('F' suffix).
struct FloatConstants {
  // Quiet (non-signaling) NaN.
  static constexpr double kNaND{std::numeric_limits<double>::quiet_NaN()};
  static constexpr float kNaNF{std::numeric_limits<float>::quiet_NaN()};

  // Largest finite representable value.
  static constexpr double kMaxD{std::numeric_limits<double>::max()};
  static constexpr float kMaxF{std::numeric_limits<float>::max()};
};
} // namespace facebook::velox::test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
165 changes: 165 additions & 0 deletions
165
velox/functions/sparksql/specialforms/AtLeastNNonNulls.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h" | ||
#include "velox/expression/ConstantExpr.h" | ||
#include "velox/expression/SpecialForm.h" | ||
|
||
using namespace facebook::velox::exec; | ||
|
||
namespace facebook::velox::functions::sparksql { | ||
namespace { | ||
class AtLeastNNonNullsExpr : public SpecialForm { | ||
public: | ||
AtLeastNNonNullsExpr( | ||
TypePtr type, | ||
bool trackCpuUsage, | ||
std::vector<ExprPtr>&& inputs, | ||
int n) | ||
: SpecialForm( | ||
std::move(type), | ||
std::move(inputs), | ||
AtLeastNNonNullsCallToSpecialForm::kAtLeastNNonNulls, | ||
true, | ||
trackCpuUsage), | ||
n_(n) {} | ||
|
||
void evalSpecialForm( | ||
const SelectivityVector& rows, | ||
EvalCtx& context, | ||
VectorPtr& result) override { | ||
context.ensureWritable(rows, type(), result); | ||
(*result).clearNulls(rows); | ||
auto flatResult = result->asFlatVector<bool>(); | ||
LocalSelectivityVector activeRowsHolder(context, rows); | ||
auto activeRows = activeRowsHolder.get(); | ||
VELOX_DCHECK_NOT_NULL(activeRows); | ||
auto values = flatResult->mutableValues(rows.end())->asMutable<uint64_t>(); | ||
// If 'n_' <= 0, set result to all true. | ||
if (n_ <= 0) { | ||
bits::orBits(values, rows.asRange().bits(), rows.begin(), rows.end()); | ||
return; | ||
} | ||
|
||
bits::andWithNegatedBits( | ||
values, rows.asRange().bits(), rows.begin(), rows.end()); | ||
// If 'n_' > inputs_.size() - 1, result should be all false. | ||
if (n_ > inputs_.size() - 1) { | ||
return; | ||
} | ||
|
||
// Create a temp buffer to track count of non null values for active rows. | ||
auto nonNullCounts = | ||
AlignedBuffer::allocate<int32_t>(activeRows->size(), context.pool(), 0); | ||
auto* rawNonNullCounts = nonNullCounts->asMutable<int32_t>(); | ||
for (column_index_t i = 1; i < inputs_.size(); ++i) { | ||
VectorPtr input; | ||
inputs_[i]->eval(*activeRows, context, input); | ||
if (context.errors()) { | ||
context.deselectErrors(*activeRows); | ||
} | ||
VELOX_DYNAMIC_TYPE_DISPATCH_ALL( | ||
updateResultTyped, | ||
inputs_[i]->type()->kind(), | ||
input.get(), | ||
n_, | ||
context, | ||
rawNonNullCounts, | ||
flatResult, | ||
activeRows); | ||
if (activeRows->countSelected() == 0) { | ||
break; | ||
} | ||
} | ||
} | ||
|
||
private: | ||
void computePropagatesNulls() override { | ||
propagatesNulls_ = false; | ||
} | ||
|
||
template <TypeKind Kind> | ||
void updateResultTyped( | ||
BaseVector* input, | ||
int32_t n, | ||
EvalCtx& context, | ||
int32_t* rawNonNullCounts, | ||
FlatVector<bool>* result, | ||
SelectivityVector* activeRows) { | ||
using T = typename TypeTraits<Kind>::NativeType; | ||
exec::LocalDecodedVector decodedVector(context); | ||
decodedVector.get()->decode(*input, *activeRows); | ||
bool updateBounds = false; | ||
activeRows->applyToSelected([&](auto row) { | ||
bool nonNull = !decodedVector->isNullAt(row); | ||
if constexpr (std::is_same_v<T, double> || std::is_same_v<T, float>) { | ||
nonNull = nonNull && !std::isnan(decodedVector->valueAt<T>(row)); | ||
} | ||
if (nonNull) { | ||
rawNonNullCounts[row]++; | ||
if (rawNonNullCounts[row] >= n) { | ||
updateBounds = true; | ||
result->set(row, true); | ||
// Exclude the 'row' from active rows after finding 'n' non-NULL / | ||
// non-NaN values. | ||
activeRows->setValid(row, false); | ||
} | ||
} | ||
}); | ||
if (updateBounds) { | ||
activeRows->updateBounds(); | ||
} | ||
} | ||
|
||
// Result is true if there are at least `n_` non-null and non-NaN values. | ||
const int n_; | ||
}; | ||
} // namespace | ||
|
||
/// Returns the result type of at_least_n_non_nulls, which is BOOLEAN
/// regardless of the argument types.
TypePtr AtLeastNNonNullsCallToSpecialForm::resolveType(
    [[maybe_unused]] const std::vector<TypePtr>& argTypes) {
  return BOOLEAN();
}
|
||
/// Validates the compiled arguments of an at_least_n_non_nulls call and builds
/// the special-form expression.
///
/// Throws a user error unless there are at least two arguments and the first
/// one is a non-null INTEGER constant held in a constant-encoded vector. That
/// constant is read once here and handed to the expression as the threshold.
///
/// @param type Result type of the expression.
/// @param compiledChildren Compiled arguments; moved into the new expression.
/// @param trackCpuUsage Forwarded to the new expression.
/// @return The constructed special-form expression.
ExprPtr AtLeastNNonNullsCallToSpecialForm::constructSpecialForm(
    const TypePtr& type,
    std::vector<ExprPtr>&& compiledChildren,
    bool trackCpuUsage,
    const core::QueryConfig& /*config*/) {
  VELOX_USER_CHECK_GT(
      compiledChildren.size(),
      1,
      "AtLeastNNonNulls expects to receive at least 2 arguments.");

  const auto& thresholdType = compiledChildren[0]->type();
  VELOX_USER_CHECK(
      thresholdType->isInteger(),
      "The first input type should be INTEGER but got {}.",
      thresholdType->toString());

  // The threshold must be known at construction time, so only constant
  // expressions are accepted.
  auto thresholdExpr =
      std::dynamic_pointer_cast<exec::ConstantExpr>(compiledChildren[0]);
  VELOX_USER_CHECK_NOT_NULL(
      thresholdExpr, "The first parameter should be constant expression.");

  const auto& thresholdValue = thresholdExpr->value();
  VELOX_USER_CHECK(
      thresholdValue->isConstantEncoding(),
      "The first parameter should be wrapped in constant vector.");

  // The type and encoding were verified above, so the unchecked cast is safe.
  auto* thresholdVector =
      thresholdValue->asUnchecked<ConstantVector<int32_t>>();
  VELOX_USER_CHECK(
      !thresholdVector->isNullAt(0), "The first parameter should not be null.");

  const int32_t n = thresholdVector->valueAt(0);
  return std::make_shared<AtLeastNNonNullsExpr>(
      type, trackCpuUsage, std::move(compiledChildren), n);
}
} // namespace facebook::velox::functions::sparksql |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include "velox/expression/FunctionCallToSpecialForm.h" | ||
|
||
namespace facebook::velox::functions::sparksql { | ||
|
||
class AtLeastNNonNullsCallToSpecialForm | ||
: public exec::FunctionCallToSpecialForm { | ||
public: | ||
TypePtr resolveType(const std::vector<TypePtr>& argTypes) override; | ||
|
||
exec::ExprPtr constructSpecialForm( | ||
const TypePtr& type, | ||
std::vector<exec::ExprPtr>&& args, | ||
bool trackCpuUsage, | ||
const core::QueryConfig& config) override; | ||
|
||
static constexpr const char* kAtLeastNNonNulls = "at_least_n_non_nulls"; | ||
}; | ||
} // namespace facebook::velox::functions::sparksql |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
velox/functions/sparksql/tests/AtLeastNNonNullsTest.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h" | ||
#include "velox/common/base/tests/FloatConstants.h" | ||
#include "velox/common/base/tests/GTestUtils.h" | ||
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" | ||
|
||
using namespace facebook::velox::test; | ||
|
||
namespace facebook::velox::functions::sparksql::test { | ||
namespace { | ||
|
||
class AtLeastNNonNullsTest : public SparkFunctionBaseTest { | ||
public: | ||
AtLeastNNonNullsTest() { | ||
// Allow for parsing literal integers as INTEGER, not BIGINT. | ||
options_.parseIntegerAsBigint = false; | ||
} | ||
}; | ||
|
||
// Checks at_least_n_non_nulls over scalar, floating-point (including NaN),
// complex, constant and dictionary-encoded inputs for several thresholds.
TEST_F(AtLeastNNonNullsTest, basic) {
  // Evaluates at_least_n_non_nulls(n, c0, c1, ...) with one column reference
  // per vector in 'input' and compares the outcome with 'expected'.
  auto testAtLeastNNonNulls = [&](int32_t n,
                                  const std::vector<VectorPtr>& input,
                                  const VectorPtr& expected) {
    std::string func = fmt::format("at_least_n_non_nulls({}", n);
    // Unsigned index to match the type of input.size().
    for (size_t i = 0; i < input.size(); ++i) {
      func += fmt::format(", c{}", i);
    }
    func += ")";
    const auto result = evaluate(func, makeRowVector(input));
    assertEqualVectors(expected, result);
  };
  auto strings = makeNullableFlatVector<StringView>(
      {std::nullopt, "1", "", std::nullopt, ""});
  auto bools = makeNullableFlatVector<bool>(
      {std::nullopt, true, false, std::nullopt, std::nullopt});
  auto ints =
      makeNullableFlatVector<int32_t>({-1, 0, 1, std::nullopt, std::nullopt});
  auto floats = makeNullableFlatVector<float>(
      {FloatConstants::kMaxF, FloatConstants::kNaNF, 0.1f, 0.0f, std::nullopt});
  // std::log(-2.0) yields NaN, so NaN values produced by computation are
  // covered as well as the NaN constant.
  auto doubles = makeNullableFlatVector<double>(
      {std::log(-2.0),
       FloatConstants::kMaxD,
       FloatConstants::kNaND,
       std::nullopt,
       0.1});
  auto arrays = makeArrayVectorFromJson<int32_t>(
      {"[1, null, 3]", "[1, 2, 3]", "null", "[null]", "[]"});
  auto maps = makeMapVectorFromJson<int32_t, int32_t>(
      {"{1: 10, 2: null, 3: null}", "{1: 10, 2: 20}", "{1: 2}", "{}", "null"});
  auto consts = makeConstant<int32_t>(2, 5);
  auto indices = makeIndices({1, 2, 3, 4, 0});
  auto dicts = wrapInDictionary(indices, 5, doubles);

  // Both arguments must be non-null.
  auto expected = makeFlatVector<bool>({false, true, true, false, false});
  testAtLeastNNonNulls(2, {strings, bools}, expected);

  // Threshold larger than the number of value arguments: always false.
  expected = makeFlatVector<bool>({false, false, false, false, false});
  testAtLeastNNonNulls(3, {strings, bools}, expected);

  // Zero or negative threshold: always true.
  expected = makeFlatVector<bool>({true, true, true, true, true});
  testAtLeastNNonNulls(0, {strings, bools}, expected);
  testAtLeastNNonNulls(-1, {strings, bools}, expected);

  // NaN is treated like null for FLOAT.
  expected = makeFlatVector<bool>({true, false, true, true, false});
  testAtLeastNNonNulls(1, {floats}, expected);

  // NaN is treated like null for DOUBLE.
  expected = makeFlatVector<bool>({false, true, false, false, true});
  testAtLeastNNonNulls(1, {doubles}, expected);

  // Mixed argument types.
  expected = makeFlatVector<bool>({false, true, true, false, false});
  testAtLeastNNonNulls(2, {strings, bools, floats}, expected);

  expected = makeFlatVector<bool>({false, true, true, false, false});
  testAtLeastNNonNulls(3, {bools, ints, floats, doubles}, expected);

  expected = makeFlatVector<bool>({false, false, false, false, false});
  testAtLeastNNonNulls(2, {floats, doubles}, expected);

  // Complex types, a constant vector and a dictionary-encoded vector.
  expected = makeFlatVector<bool>({true, false, false, true, false});
  testAtLeastNNonNulls(4, {maps, arrays, consts, dicts}, expected);
}
|
||
TEST_F(AtLeastNNonNullsTest, error) { | ||
auto input = makeFlatVector<int32_t>({1, 2, 3}); | ||
|
||
VELOX_ASSERT_USER_THROW( | ||
evaluate("at_least_n_non_nulls(1.0, c0)", makeRowVector({input})), | ||
"The first input type should be INTEGER but got DOUBLE"); | ||
VELOX_ASSERT_USER_THROW( | ||
evaluate("at_least_n_non_nulls(1)", makeRowVector({})), | ||
"AtLeastNNonNulls expects to receive at least 2 arguments"); | ||
VELOX_ASSERT_USER_THROW( | ||
evaluate("at_least_n_non_nulls(c0, c1)", makeRowVector({input, input})), | ||
"The first parameter should be constant expression"); | ||
VELOX_ASSERT_USER_THROW( | ||
evaluate( | ||
"at_least_n_non_nulls(cast(null as int), c0)", | ||
makeRowVector({input})), | ||
"The first parameter should not be null"); | ||
} | ||
} // namespace | ||
} // namespace facebook::velox::functions::sparksql::test |
Oops, something went wrong.