Skip to content

Commit

Permalink
Add Spark at_least_n_non_nulls function (#10508)
Browse files Browse the repository at this point in the history
Summary:
This function returns true if there are at least 'n' non-null and non-NaN values.

Spark's implementation: https://github.com/apache/spark/blob/110b558570176d5d5ee5ba85bb071fd66a94a7a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala#L428

Pull Request resolved: #10508

Reviewed By: xiaoxmeng

Differential Revision: D63261521

Pulled By: Yuhta

fbshipit-source-id: a38c4125defe2d1516a9404fba6b46d5c4f37e14
  • Loading branch information
zhli1142015 authored and facebook-github-bot committed Sep 23, 2024
1 parent 4e45bc5 commit bde87ce
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 0 deletions.
29 changes: 29 additions & 0 deletions velox/common/base/tests/FloatConstants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <limits>

namespace facebook::velox::test {

/// Floating-point edge-case values shared by tests: a quiet NaN and the
/// largest finite value, each provided for both double ('D' suffix) and float
/// ('F' suffix).
struct FloatConstants {
  // Quiet (non-signaling) NaN values.
  static constexpr double kNaND = std::numeric_limits<double>::quiet_NaN();
  static constexpr float kNaNF = std::numeric_limits<float>::quiet_NaN();

  // Largest finite representable values.
  static constexpr double kMaxD = std::numeric_limits<double>::max();
  static constexpr float kMaxF = std::numeric_limits<float>::max();
};
} // namespace facebook::velox::test
13 changes: 13 additions & 0 deletions velox/docs/functions/spark/misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@
Miscellaneous Functions
====================================

.. spark:function:: at_least_n_non_nulls(n, value1, value2, ..., valueN) -> bool
Returns true if there are at least ``n`` non-null and non-NaN values,
or false otherwise. ``value1, value2, ..., valueN`` are evaluated lazily.
If ``n`` non-null and non-NaN values are found, the function will stop
evaluating the remaining arguments. If ``n <= 0``, the result is true.
``n`` must be a non-null integer constant.
Nested nulls in complex type values are treated as non-null values. ::

SELECT at_least_n_non_nulls(2, 0, NAN, NULL); -- false
SELECT at_least_n_non_nulls(2, 0, 1.0, NULL); -- true
SELECT at_least_n_non_nulls(2, 0, array(NULL), NULL); -- true

.. spark:function:: monotonically_increasing_id() -> bigint
Returns monotonically increasing 64-bit integers. The generated ID is
Expand Down
4 changes: 4 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "velox/functions/sparksql/StringToMap.h"
#include "velox/functions/sparksql/UnscaledValueFunction.h"
#include "velox/functions/sparksql/Uuid.h"
#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h"
#include "velox/functions/sparksql/specialforms/DecimalRound.h"
#include "velox/functions/sparksql/specialforms/MakeDecimal.h"
#include "velox/functions/sparksql/specialforms/SparkCastExpr.h"
Expand Down Expand Up @@ -148,6 +149,9 @@ void registerAllSpecialFormGeneralFunctions() {
"cast", std::make_unique<SparkCastCallToSpecialForm>());
registerFunctionCallToSpecialForm(
"try_cast", std::make_unique<SparkTryCastCallToSpecialForm>());
exec::registerFunctionCallToSpecialForm(
AtLeastNNonNullsCallToSpecialForm::kAtLeastNNonNulls,
std::make_unique<AtLeastNNonNullsCallToSpecialForm>());
}

namespace {
Expand Down
165 changes: 165 additions & 0 deletions velox/functions/sparksql/specialforms/AtLeastNNonNulls.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h"
#include "velox/expression/ConstantExpr.h"
#include "velox/expression/SpecialForm.h"

using namespace facebook::velox::exec;

namespace facebook::velox::functions::sparksql {
namespace {
class AtLeastNNonNullsExpr : public SpecialForm {
public:
AtLeastNNonNullsExpr(
TypePtr type,
bool trackCpuUsage,
std::vector<ExprPtr>&& inputs,
int n)
: SpecialForm(
std::move(type),
std::move(inputs),
AtLeastNNonNullsCallToSpecialForm::kAtLeastNNonNulls,
true,
trackCpuUsage),
n_(n) {}

void evalSpecialForm(
const SelectivityVector& rows,
EvalCtx& context,
VectorPtr& result) override {
context.ensureWritable(rows, type(), result);
(*result).clearNulls(rows);
auto flatResult = result->asFlatVector<bool>();
LocalSelectivityVector activeRowsHolder(context, rows);
auto activeRows = activeRowsHolder.get();
VELOX_DCHECK_NOT_NULL(activeRows);
auto values = flatResult->mutableValues(rows.end())->asMutable<uint64_t>();
// If 'n_' <= 0, set result to all true.
if (n_ <= 0) {
bits::orBits(values, rows.asRange().bits(), rows.begin(), rows.end());
return;
}

bits::andWithNegatedBits(
values, rows.asRange().bits(), rows.begin(), rows.end());
// If 'n_' > inputs_.size() - 1, result should be all false.
if (n_ > inputs_.size() - 1) {
return;
}

// Create a temp buffer to track count of non null values for active rows.
auto nonNullCounts =
AlignedBuffer::allocate<int32_t>(activeRows->size(), context.pool(), 0);
auto* rawNonNullCounts = nonNullCounts->asMutable<int32_t>();
for (column_index_t i = 1; i < inputs_.size(); ++i) {
VectorPtr input;
inputs_[i]->eval(*activeRows, context, input);
if (context.errors()) {
context.deselectErrors(*activeRows);
}
VELOX_DYNAMIC_TYPE_DISPATCH_ALL(
updateResultTyped,
inputs_[i]->type()->kind(),
input.get(),
n_,
context,
rawNonNullCounts,
flatResult,
activeRows);
if (activeRows->countSelected() == 0) {
break;
}
}
}

private:
void computePropagatesNulls() override {
propagatesNulls_ = false;
}

template <TypeKind Kind>
void updateResultTyped(
BaseVector* input,
int32_t n,
EvalCtx& context,
int32_t* rawNonNullCounts,
FlatVector<bool>* result,
SelectivityVector* activeRows) {
using T = typename TypeTraits<Kind>::NativeType;
exec::LocalDecodedVector decodedVector(context);
decodedVector.get()->decode(*input, *activeRows);
bool updateBounds = false;
activeRows->applyToSelected([&](auto row) {
bool nonNull = !decodedVector->isNullAt(row);
if constexpr (std::is_same_v<T, double> || std::is_same_v<T, float>) {
nonNull = nonNull && !std::isnan(decodedVector->valueAt<T>(row));
}
if (nonNull) {
rawNonNullCounts[row]++;
if (rawNonNullCounts[row] >= n) {
updateBounds = true;
result->set(row, true);
// Exclude the 'row' from active rows after finding 'n' non-NULL /
// non-NaN values.
activeRows->setValid(row, false);
}
}
});
if (updateBounds) {
activeRows->updateBounds();
}
}

// Result is true if there are at least `n_` non-null and non-NaN values.
const int n_;
};
} // namespace

// The result type is BOOLEAN regardless of the argument types, which are not
// inspected here.
TypePtr AtLeastNNonNullsCallToSpecialForm::resolveType(
    const std::vector<TypePtr>& /*unused*/) {
  return BOOLEAN();
}

// Validates the compiled arguments and builds the special-form expression.
//
// Raises a user error when:
//  - fewer than two arguments are supplied;
//  - the first argument's type does not satisfy isInteger();
//  - the first argument is not a ConstantExpr holding a constant-encoded
//    vector;
//  - the constant is null.
// Otherwise the constant's value becomes the threshold 'n' and all compiled
// children (including the first) are handed to the expression.
ExprPtr AtLeastNNonNullsCallToSpecialForm::constructSpecialForm(
    const TypePtr& type,
    std::vector<ExprPtr>&& compiledChildren,
    bool trackCpuUsage,
    const core::QueryConfig& /*config*/) {
  VELOX_USER_CHECK_GT(
      compiledChildren.size(),
      1,
      "AtLeastNNonNulls expects to receive at least 2 arguments.");

  const auto& thresholdArg = compiledChildren[0];
  VELOX_USER_CHECK(
      thresholdArg->type()->isInteger(),
      "The first input type should be INTEGER but got {}.",
      thresholdArg->type()->toString());

  auto thresholdExpr =
      std::dynamic_pointer_cast<exec::ConstantExpr>(thresholdArg);
  VELOX_USER_CHECK_NOT_NULL(
      thresholdExpr, "The first parameter should be constant expression.");

  const auto& thresholdValue = thresholdExpr->value();
  VELOX_USER_CHECK(
      thresholdValue->isConstantEncoding(),
      "The first parameter should be wrapped in constant vector.");

  // The type and encoding were verified above, so the unchecked cast is fine.
  auto thresholdVector =
      thresholdValue->asUnchecked<ConstantVector<int32_t>>();
  VELOX_USER_CHECK(
      !thresholdVector->isNullAt(0),
      "The first parameter should not be null.");

  // Read 'n' before the children are moved into the new expression.
  const int32_t n = thresholdVector->valueAt(0);
  return std::make_shared<AtLeastNNonNullsExpr>(
      type, trackCpuUsage, std::move(compiledChildren), n);
}
} // namespace facebook::velox::functions::sparksql
35 changes: 35 additions & 0 deletions velox/functions/sparksql/specialforms/AtLeastNNonNulls.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/FunctionCallToSpecialForm.h"

namespace facebook::velox::functions::sparksql {

class AtLeastNNonNullsCallToSpecialForm
: public exec::FunctionCallToSpecialForm {
public:
TypePtr resolveType(const std::vector<TypePtr>& argTypes) override;

exec::ExprPtr constructSpecialForm(
const TypePtr& type,
std::vector<exec::ExprPtr>&& args,
bool trackCpuUsage,
const core::QueryConfig& config) override;

static constexpr const char* kAtLeastNNonNulls = "at_least_n_non_nulls";
};
} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/specialforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

velox_add_library(
velox_functions_spark_specialforms
AtLeastNNonNulls.cpp
DecimalRound.cpp
MakeDecimal.cpp
SparkCastExpr.cpp
Expand Down
116 changes: 116 additions & 0 deletions velox/functions/sparksql/tests/AtLeastNNonNullsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h"
#include "velox/common/base/tests/FloatConstants.h"
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions::sparksql::test {
namespace {

// Fixture for at_least_n_non_nulls tests.
class AtLeastNNonNullsTest : public SparkFunctionBaseTest {
 public:
  AtLeastNNonNullsTest() {
    // The function requires an INTEGER first argument, so integer literals in
    // test expressions must be parsed as INTEGER rather than BIGINT.
    options_.parseIntegerAsBigint = false;
  }
};

// Exercises the function over scalar, floating-point (including NaN), complex,
// constant and dictionary-encoded inputs with various thresholds.
TEST_F(AtLeastNNonNullsTest, basic) {
  // Evaluates at_least_n_non_nulls(n, c0, c1, ...) over 'columns' and compares
  // the outcome with 'expected'.
  auto verify = [&](int32_t n,
                    const std::vector<VectorPtr>& columns,
                    const VectorPtr& expected) {
    std::string expression = fmt::format("at_least_n_non_nulls({}", n);
    for (size_t column = 0; column < columns.size(); ++column) {
      expression += fmt::format(", c{}", column);
    }
    expression += ")";
    const auto actual = evaluate(expression, makeRowVector(columns));
    assertEqualVectors(expected, actual);
  };

  // Five-row inputs of different types and encodings.
  auto strings = makeNullableFlatVector<StringView>(
      {std::nullopt, "1", "", std::nullopt, ""});
  auto bools = makeNullableFlatVector<bool>(
      {std::nullopt, true, false, std::nullopt, std::nullopt});
  auto ints =
      makeNullableFlatVector<int32_t>({-1, 0, 1, std::nullopt, std::nullopt});
  auto floats = makeNullableFlatVector<float>(
      {FloatConstants::kMaxF, FloatConstants::kNaNF, 0.1f, 0.0f, std::nullopt});
  auto doubles = makeNullableFlatVector<double>(
      {std::log(-2.0),
       FloatConstants::kMaxD,
       FloatConstants::kNaND,
       std::nullopt,
       0.1});
  auto arrays = makeArrayVectorFromJson<int32_t>(
      {"[1, null, 3]", "[1, 2, 3]", "null", "[null]", "[]"});
  auto maps = makeMapVectorFromJson<int32_t, int32_t>(
      {"{1: 10, 2: null, 3: null}", "{1: 10, 2: 20}", "{1: 2}", "{}", "null"});
  auto constantInts = makeConstant<int32_t>(2, 5);
  auto rotation = makeIndices({1, 2, 3, 4, 0});
  auto dictDoubles = wrapInDictionary(rotation, 5, doubles);

  // Plain null handling.
  verify(
      2,
      {strings, bools},
      makeFlatVector<bool>({false, true, true, false, false}));

  // Threshold larger than the number of value arguments.
  verify(
      3,
      {strings, bools},
      makeFlatVector<bool>({false, false, false, false, false}));

  // Non-positive thresholds are always satisfied.
  const auto allTrue = makeFlatVector<bool>({true, true, true, true, true});
  verify(0, {strings, bools}, allTrue);
  verify(-1, {strings, bools}, allTrue);

  // NaN does not count as a value.
  verify(1, {floats}, makeFlatVector<bool>({true, false, true, true, false}));
  verify(1, {doubles}, makeFlatVector<bool>({false, true, false, false, true}));

  // Mixed argument types.
  verify(
      2,
      {strings, bools, floats},
      makeFlatVector<bool>({false, true, true, false, false}));
  verify(
      3,
      {bools, ints, floats, doubles},
      makeFlatVector<bool>({false, true, true, false, false}));
  verify(
      2,
      {floats, doubles},
      makeFlatVector<bool>({false, false, false, false, false}));

  // Complex types, constant and dictionary encodings.
  verify(
      4,
      {maps, arrays, constantInts, dictDoubles},
      makeFlatVector<bool>({true, false, false, true, false}));
}

// Verifies the user errors raised for invalid arguments.
TEST_F(AtLeastNNonNullsTest, error) {
  auto column = makeFlatVector<int32_t>({1, 2, 3});
  auto singleColumn = makeRowVector({column});

  // 'n' must be of INTEGER type.
  VELOX_ASSERT_USER_THROW(
      evaluate("at_least_n_non_nulls(1.0, c0)", singleColumn),
      "The first input type should be INTEGER but got DOUBLE");

  // At least one value argument is required in addition to 'n'.
  VELOX_ASSERT_USER_THROW(
      evaluate("at_least_n_non_nulls(1)", makeRowVector({})),
      "AtLeastNNonNulls expects to receive at least 2 arguments");

  // 'n' must be a constant, not a column reference.
  VELOX_ASSERT_USER_THROW(
      evaluate("at_least_n_non_nulls(c0, c1)", makeRowVector({column, column})),
      "The first parameter should be constant expression");

  // 'n' must not be null.
  VELOX_ASSERT_USER_THROW(
      evaluate("at_least_n_non_nulls(cast(null as int), c0)", singleColumn),
      "The first parameter should not be null");
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test
Loading

0 comments on commit bde87ce

Please sign in to comment.