From 2796200da36c7d033d45066e00ed09e96eb6c614 Mon Sep 17 00:00:00 2001 From: Ke Date: Thu, 22 Aug 2024 01:04:48 -0700 Subject: [PATCH] Add complex type inputs to array_except, array_intersect and arrays_overlap (#10743) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/10743 Reviewed By: kevinwilfong Differential Revision: D61246379 Pulled By: kewang1024 fbshipit-source-id: 1f2fb793a43e18abe51bfba97ba1558b8204c563 --- .../prestosql/ArrayIntersectExcept.cpp | 125 ++++++++++++++---- .../prestosql/tests/ArrayExceptTest.cpp | 64 +++++++++ .../prestosql/tests/ArrayIntersectTest.cpp | 59 +++++++++ .../prestosql/tests/ArraysOverlapTest.cpp | 65 +++++++++ velox/type/Type.h | 64 +++++++++ 5 files changed, 350 insertions(+), 27 deletions(-) diff --git a/velox/functions/prestosql/ArrayIntersectExcept.cpp b/velox/functions/prestosql/ArrayIntersectExcept.cpp index c13c812914dc..2dfcef7f3fea 100644 --- a/velox/functions/prestosql/ArrayIntersectExcept.cpp +++ b/velox/functions/prestosql/ArrayIntersectExcept.cpp @@ -20,12 +20,23 @@ namespace facebook::velox::functions { namespace { +constexpr vector_size_t kInitialSetSize{128}; + template struct SetWithNull { SetWithNull(vector_size_t initialSetSize = kInitialSetSize) { set.reserve(initialSetSize); } + bool insert(const DecodedVector* decodedElements, vector_size_t offset) { + return set.insert(decodedElements->valueAt(offset)).second; + } + + size_t count(const DecodedVector* decodedElements, vector_size_t offset) + const { + return set.count(decodedElements->valueAt(offset)); + } + void reset() { set.clear(); hasNull = false; @@ -37,7 +48,65 @@ struct SetWithNull { util::floating_point::HashSetNaNAware set; bool hasNull{false}; - static constexpr vector_size_t kInitialSetSize{128}; +}; + +struct ComplexTypeEntry { + const uint64_t hash; + const BaseVector* baseVector; + const vector_size_t index; +}; + +template <> +struct SetWithNull { + struct Hash { + size_t operator()(const ComplexTypeEntry& entry) const { + return entry.hash; + } + }; + + struct EqualTo { + bool operator()(const ComplexTypeEntry& left, const ComplexTypeEntry& right) + const { + return left.baseVector + ->equalValueAt( + right.baseVector, + left.index, + right.index, + CompareFlags::NullHandlingMode::kNullAsValue) + .value(); + } + }; + + folly::F14FastSet set; + bool hasNull{false}; + + SetWithNull(vector_size_t initialSetSize = kInitialSetSize) { + set.reserve(initialSetSize); + } + + bool insert(const DecodedVector* decodedElements, vector_size_t offset) { + const auto vector = decodedElements->base(); + const auto index = decodedElements->index(offset); + const uint64_t hash = vector->hashValueAt(index); + return set.insert(ComplexTypeEntry{hash, vector, index}).second; + } + + size_t count(const DecodedVector* decodedElements, vector_size_t offset) + const { + const auto vector = decodedElements->base(); + const auto index = decodedElements->index(offset); + const uint64_t hash = vector->hashValueAt(index); + return set.count(ComplexTypeEntry{hash, vector, index}); + } + + void reset() { + set.clear(); + hasNull = false; + } + + bool empty() const { + return !hasNull && set.empty(); + } }; // Generates a set based on the elements of an ArrayVector. Note that we take @@ -57,7 +126,7 @@ void generateSet( if (arrayElements->isNullAt(i)) { rightSet.hasNull = true; } else { - rightSet.set.insert(arrayElements->template valueAt(i)); + rightSet.insert(arrayElements, i); } } } @@ -186,19 +255,17 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction { } } } else { - auto val = decodedLeftElements->valueAt(i); // For array_intersect, add the element if it is found (not found // for array_except) in the right-hand side, and wasn't added already // (check outputSet). bool addValue = false; if constexpr (isIntersect) { - addValue = rightSet.set.count(val) > 0; + addValue = rightSet.count(decodedLeftElements, i) > 0; } else { - addValue = rightSet.set.count(val) == 0; + addValue = rightSet.count(decodedLeftElements, i) == 0; } if (addValue) { - auto it = outputSet.set.insert(val); - if (it.second) { + if (outputSet.insert(decodedLeftElements, i)) { rawNewIndices[indicesCursor++] = i; } } @@ -294,7 +361,7 @@ class ArraysOverlapFunction : public exec::VectorFunction { hasNull = true; continue; } - if (rightSet.set.count(decodedLeftElements->valueAt(i)) > 0) { + if (rightSet.count(decodedLeftElements, i) > 0) { // Found an overlapping element. Add to result set. resultBoolVector->set(row, true); return; @@ -396,7 +463,10 @@ SetWithNull validateConstantVectorAndGenerateSet( template std::shared_ptr createTypedArraysIntersectExcept( const std::vector& inputArgs) { - using T = typename TypeTraits::NativeType; + using T = std::conditional_t< + TypeTraits::isPrimitiveType, + typename TypeTraits::NativeType, + ComplexTypeEntry>; VELOX_CHECK_EQ(inputArgs.size(), 2); BaseVector* rhs = inputArgs[1].constantValue.get(); @@ -424,7 +494,7 @@ std::shared_ptr createArrayIntersect( validateMatchingArrayTypes(inputArgs, name, 2); auto elementType = inputArgs.front().type->childAt(0); - return VELOX_DYNAMIC_SCALAR_TEMPLATE_TYPE_DISPATCH( + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( createTypedArraysIntersectExcept, /* isIntersect */ true, elementType->kind(), @@ -438,7 +508,7 @@ std::shared_ptr createArrayExcept( validateMatchingArrayTypes(inputArgs, name, 2); auto elementType = inputArgs.front().type->childAt(0); - return VELOX_DYNAMIC_SCALAR_TEMPLATE_TYPE_DISPATCH( + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( createTypedArraysIntersectExcept, /* isIntersect */ false, elementType->kind(), @@ -446,18 +516,15 @@ std::shared_ptr createArrayExcept( } std::vector> signatures( - const std::string& returnTypeTemplate) { - std::vector> signatures; - for (const auto& type : exec::primitiveTypeNames()) { - signatures.push_back( - exec::FunctionSignatureBuilder() - .returnType( - fmt::format(fmt::runtime(returnTypeTemplate.c_str()), type)) - .argumentType(fmt::format("array({})", type)) - .argumentType(fmt::format("array({})", type)) - .build()); - } - return signatures; + const std::string& returnType) { + return std::vector>{ + exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType(returnType) + .argumentType("array(T)") + .argumentType("array(T)") + .build(), + }; } template @@ -466,7 +533,11 @@ const std::shared_ptr createTypedArraysOverlap( VELOX_CHECK_EQ(inputArgs.size(), 2); auto left = inputArgs[0].constantValue.get(); auto right = inputArgs[1].constantValue.get(); - using T = typename TypeTraits::NativeType; + using T = std::conditional_t< + TypeTraits::isPrimitiveType, + typename TypeTraits::NativeType, + ComplexTypeEntry>; + if (left == nullptr && right == nullptr) { return std::make_shared>(); } @@ -484,7 +555,7 @@ std::shared_ptr createArraysOverlapFunction( validateMatchingArrayTypes(inputArgs, name, 2); auto elementType = inputArgs.front().type->childAt(0); - return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + return VELOX_DYNAMIC_TYPE_DISPATCH( createTypedArraysOverlap, elementType->kind(), inputArgs); } } // namespace @@ -496,11 +567,11 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_array_intersect, - signatures("array({})"), + signatures("array(T)"), createArrayIntersect); VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_array_except, - signatures("array({})"), + signatures("array(T)"), createArrayExcept); } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/tests/ArrayExceptTest.cpp b/velox/functions/prestosql/tests/ArrayExceptTest.cpp index aef7de5d2e16..2a23e14788f3 100644 --- a/velox/functions/prestosql/tests/ArrayExceptTest.cpp +++ b/velox/functions/prestosql/tests/ArrayExceptTest.cpp @@ -24,6 +24,9 @@ using namespace facebook::velox::functions::test; namespace { +template +using Pair = std::pair>; + class ArrayExceptTest : public FunctionBaseTest { protected: void testExpr( @@ -278,6 +281,67 @@ TEST_F(ArrayExceptTest, varbinary) { testExpr(expected, "array_except(c0, c1)", {right, left}); } +TEST_F(ArrayExceptTest, complexTypeArray) { + auto left = makeNestedArrayVectorFromJson({ + "[null, [1, 2, 3], [null, null]]", + "[[1], [2], []]", + "[[1, null, 3]]", + "[[1, null, 3]]", + }); + + auto right = makeNestedArrayVectorFromJson({ + "[[1, 2, 3]]", + "[[1]]", + "[[1, null, 3], [1, 2]]", + "[[1, null, 3, null]]", + }); + + auto expected = makeNestedArrayVectorFromJson({ + "[null, [null, null]]", + "[[2], []]", + "[]", + "[[1, null, 3]]", + }); + testExpr(expected, "array_except(c0, c1)", {left, right}); +} + +TEST_F(ArrayExceptTest, complexTypeMap) { + std::vector> a{{"blue", 1}, {"red", 2}}; + std::vector> b{{"blue", 2}, {"red", 2}}; + std::vector> c{{"green", std::nullopt}}; + std::vector> d{{"yellow", 4}, {"purple", 5}}; + std::vector>>> leftData{ + {b, a}, {b}, {c, a}}; + std::vector>>> rightData{ + {a, b}, {}, {a}}; + std::vector>>> expectedData{ + {}, {b}, {c}}; + + auto left = makeArrayOfMapVector(leftData); + auto right = makeArrayOfMapVector(rightData); + auto expected = makeArrayOfMapVector(expectedData); + + testExpr(expected, "array_except(c0, c1)", {left, right}); +} + +TEST_F(ArrayExceptTest, complexTypeRow) { + RowTypePtr rowType = ROW({INTEGER(), VARCHAR()}); + + using ArrayOfRow = std::vector>>; + std::vector leftData = { + {{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}}, + {{{1, "red"}}, {{2, "blue"}}, {}}, + {{{1, "red"}}, std::nullopt, std::nullopt}}; + std::vector rightData = { + {{{2, "blue"}}, {{1, "red"}}}, {{}, {{1, "green"}}}, {{{1, "red"}}}}; + std::vector expectedData = { + {{{3, "green"}}}, {{{1, "red"}}, {{2, "blue"}}}, {std::nullopt}}; + auto left = makeArrayOfRowVector(leftData, rowType); + auto right = makeArrayOfRowVector(rightData, rowType); + auto expected = makeArrayOfRowVector(expectedData, rowType); + testExpr(expected, "array_except(c0, c1)", {left, right}); +} + // When one of the arrays is constant. TEST_F(ArrayExceptTest, constant) { auto array1 = makeNullableArrayVector({ diff --git a/velox/functions/prestosql/tests/ArrayIntersectTest.cpp b/velox/functions/prestosql/tests/ArrayIntersectTest.cpp index 8260f9fbdf9c..3ae43b8acb81 100644 --- a/velox/functions/prestosql/tests/ArrayIntersectTest.cpp +++ b/velox/functions/prestosql/tests/ArrayIntersectTest.cpp @@ -24,6 +24,9 @@ using namespace facebook::velox::functions::test; namespace { +template +using Pair = std::pair>; + class ArrayIntersectTest : public FunctionBaseTest { protected: void testExpr( @@ -257,6 +260,62 @@ TEST_F(ArrayIntersectTest, varbinary) { testExpr(expected, "array_intersect(c0, c1)", {right, left}); } +TEST_F(ArrayIntersectTest, complexTypeArray) { + auto left = makeNestedArrayVectorFromJson({ + "[null, [1, 2, 3], [null, null]]", + "[[1], [2], []]", + "[[1, null, 3]]", + "[[1, null, 3]]", + }); + + auto right = makeNestedArrayVectorFromJson({ + "[[1, 2, 3]]", + "[[1]]", + "[[1, null, 3], [1, 2]]", + "[[1, null, 3, null]]", + }); + + auto expected = makeNestedArrayVectorFromJson( + {"[[1, 2, 3]]", "[[1]]", "[[1, null, 3]]", "[]"}); + testExpr(expected, "array_intersect(c0, c1)", {left, right}); +} + +TEST_F(ArrayIntersectTest, complexTypeMap) { + std::vector> a{{"blue", 1}, {"red", 2}}; + std::vector> b{{"green", std::nullopt}}; + std::vector> c{{"yellow", 4}, {"purple", 5}}; + std::vector>>> leftData{ + {b, a}, {b}, {c, a}}; + std::vector>>> rightData{ + {a, b}, {}, {a}}; + std::vector>>> expectedData{ + {b, a}, {}, {a}}; + + auto left = makeArrayOfMapVector(leftData); + auto right = makeArrayOfMapVector(rightData); + auto expected = makeArrayOfMapVector(expectedData); + + testExpr(expected, "array_intersect(c0, c1)", {left, right}); +} + +TEST_F(ArrayIntersectTest, complexTypeRow) { + RowTypePtr rowType = ROW({INTEGER(), VARCHAR()}); + + using ArrayOfRow = std::vector>>; + std::vector leftData = { + {{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}}, + {{{1, "red"}}, {{2, "blue"}}, {}}, + {{{1, "red"}}, std::nullopt, std::nullopt}}; + std::vector rightData = { + {{{2, "blue"}}, {{1, "red"}}}, {{}, {{1, "green"}}}, {{{1, "red"}}}}; + std::vector expectedData = { + {{{1, "red"}}, {{2, "blue"}}}, {{}}, {{{1, "red"}}}}; + auto left = makeArrayOfRowVector(leftData, rowType); + auto right = makeArrayOfRowVector(rightData, rowType); + auto expected = makeArrayOfRowVector(expectedData, rowType); + testExpr(expected, "array_intersect(c0, c1)", {left, right}); +} + // When one of the arrays is constant. TEST_F(ArrayIntersectTest, constant) { auto array1 = makeNullableArrayVector({ diff --git a/velox/functions/prestosql/tests/ArraysOverlapTest.cpp b/velox/functions/prestosql/tests/ArraysOverlapTest.cpp index 002a6175cfcc..002490be3ab8 100644 --- a/velox/functions/prestosql/tests/ArraysOverlapTest.cpp +++ b/velox/functions/prestosql/tests/ArraysOverlapTest.cpp @@ -23,6 +23,9 @@ using namespace facebook::velox::test; using namespace facebook::velox::functions::test; namespace { +template +using Pair = std::pair>; + class ArraysOverlapTest : public FunctionBaseTest { protected: void testExpr( @@ -210,6 +213,68 @@ TEST_F(ArraysOverlapTest, longStrings) { testExpr(expected, "arrays_overlap(C1, C0)", {array1, array2}); } +TEST_F(ArraysOverlapTest, complexTypeArray) { + auto left = makeNestedArrayVectorFromJson({ + "[null, [1, 2, 3], [null, null]]", + "[[1], [2], []]", + "[[1, null, 3]]", + "[[1, null, 3]]", + "[null]", + }); + + auto right = makeNestedArrayVectorFromJson({ + "[[1, 2, 3]]", + "[[1]]", + "[[1, 2]]", + "[[1, null, 3]]", + "[[]]", + }); + + auto expected = + makeNullableFlatVector({true, true, false, true, std::nullopt}); + testExpr(expected, "arrays_overlap(c0, c1)", {left, right}); +} + +TEST_F(ArraysOverlapTest, complexTypeMap) { + std::vector> a{{"blue", 1}, {"red", 2}}; + std::vector> b{{"green", std::nullopt}}; + std::vector> c{{"yellow", 4}, {"purple", 5}}; + + std::vector>>> leftData{ + {b, a}, {b}, {c, a}}; + std::vector>>> rightData{ + {a, b}, {}, {b}}; + + auto left = makeArrayOfMapVector(leftData); + auto right = makeArrayOfMapVector(rightData); + auto expected = makeNullableFlatVector({true, false, false}); + + testExpr(expected, "arrays_overlap(c0, c1)", {left, right}); +} + +TEST_F(ArraysOverlapTest, complexTypeRow) { + RowTypePtr rowType = ROW({INTEGER(), VARCHAR()}); + + using ArrayOfRow = std::vector>>; + std::vector leftData = { + {{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}}, + {{{1, "red"}}, {{2, "blue"}}, std::nullopt}, + {{{1, "red"}}, std::nullopt, std::nullopt}, + {{{1, "red"}}, {{}}, {{}}}}; + std::vector rightData = { + {{{2, "blue"}}, {{1, "red"}}}, + {{{1, "green"}}}, + {{{1, "red"}}}, + {{{2, "red"}}}}; + + auto left = makeArrayOfRowVector(leftData, rowType); + auto right = makeArrayOfRowVector(rightData, rowType); + auto expected = + makeNullableFlatVector({true, std::nullopt, true, false}); + + testExpr(expected, "arrays_overlap(c0, c1)", {left, right}); +} + //// When one of the arrays is constant. TEST_F(ArraysOverlapTest, constant) { auto array1 = makeNullableArrayVector({ diff --git a/velox/type/Type.h b/velox/type/Type.h index 8a004dec043a..b7bc6bbdbd3d 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -1502,6 +1502,70 @@ std::shared_ptr OPAQUE() { } \ }() +#define VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(TEMPLATE_FUNC, T, typeKind, ...) \ + [&]() { \ + switch (typeKind) { \ + case ::facebook::velox::TypeKind::BOOLEAN: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::INTEGER: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::TINYINT: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::SMALLINT: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::BIGINT: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::HUGEINT: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::REAL: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::DOUBLE: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::VARCHAR: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::VARBINARY: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::TIMESTAMP: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::MAP: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::ARRAY: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::ROW: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + default: \ + VELOX_FAIL("not a known type kind: {}", mapTypeKindToName(typeKind)); \ + } \ + }() + #define VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL(TEMPLATE_FUNC, typeKind, ...) \ [&]() { \ if ((typeKind) == ::facebook::velox::TypeKind::UNKNOWN) { \