Skip to content

Commit

Permalink
Add complex type inputs to array_except, array_intersect and arrays_o…
Browse files Browse the repository at this point in the history
…verlap (facebookincubator#10743)

Summary: Pull Request resolved: facebookincubator#10743

Reviewed By: kevinwilfong

Differential Revision: D61246379

Pulled By: kewang1024

fbshipit-source-id: 1f2fb793a43e18abe51bfba97ba1558b8204c563
  • Loading branch information
kewang1024 authored and facebook-github-bot committed Aug 22, 2024
1 parent 8b75147 commit 2796200
Show file tree
Hide file tree
Showing 5 changed files with 350 additions and 27 deletions.
125 changes: 98 additions & 27 deletions velox/functions/prestosql/ArrayIntersectExcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,23 @@

namespace facebook::velox::functions {
namespace {
constexpr vector_size_t kInitialSetSize{128};

template <typename T>
struct SetWithNull {
SetWithNull(vector_size_t initialSetSize = kInitialSetSize) {
set.reserve(initialSetSize);
}

bool insert(const DecodedVector* decodedElements, vector_size_t offset) {
return set.insert(decodedElements->valueAt<T>(offset)).second;
}

size_t count(const DecodedVector* decodedElements, vector_size_t offset)
const {
return set.count(decodedElements->valueAt<T>(offset));
}

void reset() {
set.clear();
hasNull = false;
Expand All @@ -37,7 +48,65 @@ struct SetWithNull {

util::floating_point::HashSetNaNAware<T> set;
bool hasNull{false};
static constexpr vector_size_t kInitialSetSize{128};
};

struct ComplexTypeEntry {
const uint64_t hash;
const BaseVector* baseVector;
const vector_size_t index;
};

template <>
struct SetWithNull<ComplexTypeEntry> {
struct Hash {
size_t operator()(const ComplexTypeEntry& entry) const {
return entry.hash;
}
};

struct EqualTo {
bool operator()(const ComplexTypeEntry& left, const ComplexTypeEntry& right)
const {
return left.baseVector
->equalValueAt(
right.baseVector,
left.index,
right.index,
CompareFlags::NullHandlingMode::kNullAsValue)
.value();
}
};

folly::F14FastSet<ComplexTypeEntry, Hash, EqualTo> set;
bool hasNull{false};

SetWithNull(vector_size_t initialSetSize = kInitialSetSize) {
set.reserve(initialSetSize);
}

bool insert(const DecodedVector* decodedElements, vector_size_t offset) {
const auto vector = decodedElements->base();
const auto index = decodedElements->index(offset);
const uint64_t hash = vector->hashValueAt(index);
return set.insert(ComplexTypeEntry{hash, vector, index}).second;
}

size_t count(const DecodedVector* decodedElements, vector_size_t offset)
const {
const auto vector = decodedElements->base();
const auto index = decodedElements->index(offset);
const uint64_t hash = vector->hashValueAt(index);
return set.count(ComplexTypeEntry{hash, vector, index});
}

void reset() {
set.clear();
hasNull = false;
}

bool empty() const {
return !hasNull && set.empty();
}
};

// Generates a set based on the elements of an ArrayVector. Note that we take
Expand All @@ -57,7 +126,7 @@ void generateSet(
if (arrayElements->isNullAt(i)) {
rightSet.hasNull = true;
} else {
rightSet.set.insert(arrayElements->template valueAt<T>(i));
rightSet.insert(arrayElements, i);
}
}
}
Expand Down Expand Up @@ -186,19 +255,17 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction {
}
}
} else {
auto val = decodedLeftElements->valueAt<T>(i);
// For array_intersect, add the element if it is found (not found
// for array_except) in the right-hand side, and wasn't added already
// (check outputSet).
bool addValue = false;
if constexpr (isIntersect) {
addValue = rightSet.set.count(val) > 0;
addValue = rightSet.count(decodedLeftElements, i) > 0;
} else {
addValue = rightSet.set.count(val) == 0;
addValue = rightSet.count(decodedLeftElements, i) == 0;
}
if (addValue) {
auto it = outputSet.set.insert(val);
if (it.second) {
if (outputSet.insert(decodedLeftElements, i)) {
rawNewIndices[indicesCursor++] = i;
}
}
Expand Down Expand Up @@ -294,7 +361,7 @@ class ArraysOverlapFunction : public exec::VectorFunction {
hasNull = true;
continue;
}
if (rightSet.set.count(decodedLeftElements->valueAt<T>(i)) > 0) {
if (rightSet.count(decodedLeftElements, i) > 0) {
// Found an overlapping element. Add to result set.
resultBoolVector->set(row, true);
return;
Expand Down Expand Up @@ -396,7 +463,10 @@ SetWithNull<T> validateConstantVectorAndGenerateSet(
template <bool isIntersect, TypeKind kind>
std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
const std::vector<exec::VectorFunctionArg>& inputArgs) {
using T = typename TypeTraits<kind>::NativeType;
using T = std::conditional_t<
TypeTraits<kind>::isPrimitiveType,
typename TypeTraits<kind>::NativeType,
ComplexTypeEntry>;

VELOX_CHECK_EQ(inputArgs.size(), 2);
BaseVector* rhs = inputArgs[1].constantValue.get();
Expand Down Expand Up @@ -424,7 +494,7 @@ std::shared_ptr<exec::VectorFunction> createArrayIntersect(
validateMatchingArrayTypes(inputArgs, name, 2);
auto elementType = inputArgs.front().type->childAt(0);

return VELOX_DYNAMIC_SCALAR_TEMPLATE_TYPE_DISPATCH(
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
createTypedArraysIntersectExcept,
/* isIntersect */ true,
elementType->kind(),
Expand All @@ -438,26 +508,23 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
validateMatchingArrayTypes(inputArgs, name, 2);
auto elementType = inputArgs.front().type->childAt(0);

return VELOX_DYNAMIC_SCALAR_TEMPLATE_TYPE_DISPATCH(
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
createTypedArraysIntersectExcept,
/* isIntersect */ false,
elementType->kind(),
inputArgs);
}

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures(
const std::string& returnTypeTemplate) {
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
for (const auto& type : exec::primitiveTypeNames()) {
signatures.push_back(
exec::FunctionSignatureBuilder()
.returnType(
fmt::format(fmt::runtime(returnTypeTemplate.c_str()), type))
.argumentType(fmt::format("array({})", type))
.argumentType(fmt::format("array({})", type))
.build());
}
return signatures;
const std::string& returnType) {
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
exec::FunctionSignatureBuilder()
.typeVariable("T")
.returnType(returnType)
.argumentType("array(T)")
.argumentType("array(T)")
.build(),
};
}

template <TypeKind kind>
Expand All @@ -466,7 +533,11 @@ const std::shared_ptr<exec::VectorFunction> createTypedArraysOverlap(
VELOX_CHECK_EQ(inputArgs.size(), 2);
auto left = inputArgs[0].constantValue.get();
auto right = inputArgs[1].constantValue.get();
using T = typename TypeTraits<kind>::NativeType;
using T = std::conditional_t<
TypeTraits<kind>::isPrimitiveType,
typename TypeTraits<kind>::NativeType,
ComplexTypeEntry>;

if (left == nullptr && right == nullptr) {
return std::make_shared<ArraysOverlapFunction<T>>();
}
Expand All @@ -484,7 +555,7 @@ std::shared_ptr<exec::VectorFunction> createArraysOverlapFunction(
validateMatchingArrayTypes(inputArgs, name, 2);
auto elementType = inputArgs.front().type->childAt(0);

return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
return VELOX_DYNAMIC_TYPE_DISPATCH(
createTypedArraysOverlap, elementType->kind(), inputArgs);
}
} // namespace
Expand All @@ -496,11 +567,11 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_array_intersect,
signatures("array({})"),
signatures("array(T)"),
createArrayIntersect);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_array_except,
signatures("array({})"),
signatures("array(T)"),
createArrayExcept);
} // namespace facebook::velox::functions
64 changes: 64 additions & 0 deletions velox/functions/prestosql/tests/ArrayExceptTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ using namespace facebook::velox::functions::test;

namespace {

template <typename TKey, typename TValue>
using Pair = std::pair<TKey, std::optional<TValue>>;

class ArrayExceptTest : public FunctionBaseTest {
protected:
void testExpr(
Expand Down Expand Up @@ -278,6 +281,67 @@ TEST_F(ArrayExceptTest, varbinary) {
testExpr(expected, "array_except(c0, c1)", {right, left});
}

TEST_F(ArrayExceptTest, complexTypeArray) {
auto left = makeNestedArrayVectorFromJson<int32_t>({
"[null, [1, 2, 3], [null, null]]",
"[[1], [2], []]",
"[[1, null, 3]]",
"[[1, null, 3]]",
});

auto right = makeNestedArrayVectorFromJson<int32_t>({
"[[1, 2, 3]]",
"[[1]]",
"[[1, null, 3], [1, 2]]",
"[[1, null, 3, null]]",
});

auto expected = makeNestedArrayVectorFromJson<int32_t>({
"[null, [null, null]]",
"[[2], []]",
"[]",
"[[1, null, 3]]",
});
testExpr(expected, "array_except(c0, c1)", {left, right});
}

TEST_F(ArrayExceptTest, complexTypeMap) {
std::vector<Pair<StringView, int64_t>> a{{"blue", 1}, {"red", 2}};
std::vector<Pair<StringView, int64_t>> b{{"blue", 2}, {"red", 2}};
std::vector<Pair<StringView, int64_t>> c{{"green", std::nullopt}};
std::vector<Pair<StringView, int64_t>> d{{"yellow", 4}, {"purple", 5}};
std::vector<std::vector<std::vector<Pair<StringView, int64_t>>>> leftData{
{b, a}, {b}, {c, a}};
std::vector<std::vector<std::vector<Pair<StringView, int64_t>>>> rightData{
{a, b}, {}, {a}};
std::vector<std::vector<std::vector<Pair<StringView, int64_t>>>> expectedData{
{}, {b}, {c}};

auto left = makeArrayOfMapVector<StringView, int64_t>(leftData);
auto right = makeArrayOfMapVector<StringView, int64_t>(rightData);
auto expected = makeArrayOfMapVector<StringView, int64_t>(expectedData);

testExpr(expected, "array_except(c0, c1)", {left, right});
}

TEST_F(ArrayExceptTest, complexTypeRow) {
RowTypePtr rowType = ROW({INTEGER(), VARCHAR()});

using ArrayOfRow = std::vector<std::optional<std::tuple<int, std::string>>>;
std::vector<ArrayOfRow> leftData = {
{{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}},
{{{1, "red"}}, {{2, "blue"}}, {}},
{{{1, "red"}}, std::nullopt, std::nullopt}};
std::vector<ArrayOfRow> rightData = {
{{{2, "blue"}}, {{1, "red"}}}, {{}, {{1, "green"}}}, {{{1, "red"}}}};
std::vector<ArrayOfRow> expectedData = {
{{{3, "green"}}}, {{{1, "red"}}, {{2, "blue"}}}, {std::nullopt}};
auto left = makeArrayOfRowVector(leftData, rowType);
auto right = makeArrayOfRowVector(rightData, rowType);
auto expected = makeArrayOfRowVector(expectedData, rowType);
testExpr(expected, "array_except(c0, c1)", {left, right});
}

// When one of the arrays is constant.
TEST_F(ArrayExceptTest, constant) {
auto array1 = makeNullableArrayVector<int32_t>({
Expand Down
59 changes: 59 additions & 0 deletions velox/functions/prestosql/tests/ArrayIntersectTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ using namespace facebook::velox::functions::test;

namespace {

template <typename TKey, typename TValue>
using Pair = std::pair<TKey, std::optional<TValue>>;

class ArrayIntersectTest : public FunctionBaseTest {
protected:
void testExpr(
Expand Down Expand Up @@ -257,6 +260,62 @@ TEST_F(ArrayIntersectTest, varbinary) {
testExpr(expected, "array_intersect(c0, c1)", {right, left});
}

TEST_F(ArrayIntersectTest, complexTypeArray) {
auto left = makeNestedArrayVectorFromJson<int32_t>({
"[null, [1, 2, 3], [null, null]]",
"[[1], [2], []]",
"[[1, null, 3]]",
"[[1, null, 3]]",
});

auto right = makeNestedArrayVectorFromJson<int32_t>({
"[[1, 2, 3]]",
"[[1]]",
"[[1, null, 3], [1, 2]]",
"[[1, null, 3, null]]",
});

auto expected = makeNestedArrayVectorFromJson<int32_t>(
{"[[1, 2, 3]]", "[[1]]", "[[1, null, 3]]", "[]"});
testExpr(expected, "array_intersect(c0, c1)", {left, right});
}

TEST_F(ArrayIntersectTest, complexTypeMap) {
std::vector<Pair<StringView, int64_t>> a{{"blue", 1}, {"red", 2}};
std::vector<Pair<StringView, int64_t>> b{{"green", std::nullopt}};
std::vector<Pair<StringView, int64_t>> c{{"yellow", 4}, {"purple", 5}};
std::vector<std::vector<std::vector<Pair<StringView, int64_t>>>> leftData{
{b, a}, {b}, {c, a}};
std::vector<std::vector<std::vector<Pair<StringView, int64_t>>>> rightData{
{a, b}, {}, {a}};
std::vector<std::vector<std::vector<Pair<StringView, int64_t>>>> expectedData{
{b, a}, {}, {a}};

auto left = makeArrayOfMapVector<StringView, int64_t>(leftData);
auto right = makeArrayOfMapVector<StringView, int64_t>(rightData);
auto expected = makeArrayOfMapVector<StringView, int64_t>(expectedData);

testExpr(expected, "array_intersect(c0, c1)", {left, right});
}

TEST_F(ArrayIntersectTest, complexTypeRow) {
RowTypePtr rowType = ROW({INTEGER(), VARCHAR()});

using ArrayOfRow = std::vector<std::optional<std::tuple<int, std::string>>>;
std::vector<ArrayOfRow> leftData = {
{{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}},
{{{1, "red"}}, {{2, "blue"}}, {}},
{{{1, "red"}}, std::nullopt, std::nullopt}};
std::vector<ArrayOfRow> rightData = {
{{{2, "blue"}}, {{1, "red"}}}, {{}, {{1, "green"}}}, {{{1, "red"}}}};
std::vector<ArrayOfRow> expectedData = {
{{{1, "red"}}, {{2, "blue"}}}, {{}}, {{{1, "red"}}}};
auto left = makeArrayOfRowVector(leftData, rowType);
auto right = makeArrayOfRowVector(rightData, rowType);
auto expected = makeArrayOfRowVector(expectedData, rowType);
testExpr(expected, "array_intersect(c0, c1)", {left, right});
}

// When one of the arrays is constant.
TEST_F(ArrayIntersectTest, constant) {
auto array1 = makeNullableArrayVector<int32_t>({
Expand Down
Loading

0 comments on commit 2796200

Please sign in to comment.