From 12ba118cb544432ff949f5b2cdf63177a0f16abd Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Tue, 9 Aug 2022 20:20:28 -0700 Subject: [PATCH] Add transform_keys and transform_values Presto functions (#2245) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/2245 Differential Revision: D38565343 Pulled By: mbasmanova fbshipit-source-id: a6a670105d921cf6fd8d88f4d52e112f1a7c00a4 --- velox/docs/functions/map.rst | 23 ++ velox/functions/prestosql/CMakeLists.txt | 2 + velox/functions/prestosql/Transform.cpp | 2 +- velox/functions/prestosql/TransformKeys.cpp | 126 +++++++++ velox/functions/prestosql/TransformValues.cpp | 99 +++++++ .../registration/MapFunctionsRegistration.cpp | 2 + .../functions/prestosql/tests/CMakeLists.txt | 2 + .../prestosql/tests/TransformKeysTest.cpp | 241 ++++++++++++++++++ .../prestosql/tests/TransformValuesTest.cpp | 219 ++++++++++++++++ velox/vector/tests/VectorTestBase.cpp | 4 +- 10 files changed, 718 insertions(+), 2 deletions(-) create mode 100644 velox/functions/prestosql/TransformKeys.cpp create mode 100644 velox/functions/prestosql/TransformValues.cpp create mode 100644 velox/functions/prestosql/tests/TransformKeysTest.cpp create mode 100644 velox/functions/prestosql/tests/TransformValuesTest.cpp diff --git a/velox/docs/functions/map.rst b/velox/docs/functions/map.rst index b6c16dfe026b..4f5f7bd25d6a 100644 --- a/velox/docs/functions/map.rst +++ b/velox/docs/functions/map.rst @@ -59,3 +59,26 @@ Map Functions Corresponds to SQL subscript operator []. SELECT name_to_age_map['Bob'] AS bob_age; + +.. function:: transform_keys(map(K1,V), function(K1,V,K2)) -> map(K2,V) + + Returns a map that applies ``function`` to each entry of ``map`` and transforms the keys:: + + SELECT transform_keys(MAP(ARRAY[], ARRAY[]), (k, v) -> k + 1); -- {} + SELECT transform_keys(MAP(ARRAY [1, 2, 3], ARRAY ['a', 'b', 'c']), (k, v) -> k + 1); -- {2 -> a, 3 -> b, 4 -> c} + SELECT transform_keys(MAP(ARRAY ['a', 'b', 'c'], ARRAY [1, 2, 3]), (k, v) -> v * v); -- {1 -> 1, 4 -> 2, 9 -> 3} + SELECT transform_keys(MAP(ARRAY ['a', 'b'], ARRAY [1, 2]), (k, v) -> k || CAST(v as VARCHAR)); -- {a1 -> 1, b2 -> 2} + SELECT transform_keys(MAP(ARRAY [1, 2], ARRAY [1.0, 1.4]), -- {one -> 1.0, two -> 1.4} + (k, v) -> MAP(ARRAY[1, 2], ARRAY['one', 'two'])[k]); + +.. function:: transform_values(map(K,V1), function(K,V1,V2)) -> map(K,V2) + + Returns a map that applies ``function`` to each entry of ``map`` and transforms the values:: + + SELECT transform_values(MAP(ARRAY[], ARRAY[]), (k, v) -> v + 1); -- {} + SELECT transform_values(MAP(ARRAY [1, 2, 3], ARRAY [10, 20, 30]), (k, v) -> v + k); -- {1 -> 11, 2 -> 22, 3 -> 33} + SELECT transform_values(MAP(ARRAY [1, 2, 3], ARRAY ['a', 'b', 'c']), (k, v) -> k * k); -- {1 -> 1, 2 -> 4, 3 -> 9} + SELECT transform_values(MAP(ARRAY ['a', 'b'], ARRAY [1, 2]), (k, v) -> k || CAST(v as VARCHAR)); -- {a -> a1, b -> b2} + SELECT transform_values(MAP(ARRAY [1, 2], ARRAY [1.0, 1.4]), -- {1 -> one_1.0, 2 -> two_1.4} + (k, v) -> MAP(ARRAY[1, 2], ARRAY['one', 'two'])[k] || '_' || CAST(v AS VARCHAR)); + diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 386c652cbef4..572e9842e935 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -44,6 +44,8 @@ add_library( Subscript.cpp ToUtf8.cpp Transform.cpp + TransformKeys.cpp + TransformValues.cpp URLFunctions.cpp VectorArithmetic.cpp WidthBucketArray.cpp diff --git a/velox/functions/prestosql/Transform.cpp b/velox/functions/prestosql/Transform.cpp index c4915f5c260d..e564da6ee452 100644 --- a/velox/functions/prestosql/Transform.cpp +++ b/velox/functions/prestosql/Transform.cpp @@ -26,7 +26,7 @@ class TransformFunction : public exec::VectorFunction { public: bool isDefaultNullBehavior() const override { // transform is null preserving for the array. But since an - // expr tree witht a lambda depends on all named fields, including + // expr tree with a lambda depends on all named fields, including // captures, a null in a capture does not automatically make a // null result. return false; diff --git a/velox/functions/prestosql/TransformKeys.cpp b/velox/functions/prestosql/TransformKeys.cpp new file mode 100644 index 000000000000..6f274acd1e81 --- /dev/null +++ b/velox/functions/prestosql/TransformKeys.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/expression/Expr.h" +#include "velox/expression/VectorFunction.h" +#include "velox/functions/lib/LambdaFunctionUtil.h" +#include "velox/vector/FunctionVector.h" + +namespace facebook::velox::functions { +namespace { + +// See documentation at https://prestodb.io/docs/current/functions/map.html +class TransformKeysFunction : public exec::VectorFunction { + public: + bool isDefaultNullBehavior() const override { + // transform_keys is null preserving for the map. But + // since an expr tree with a lambda depends on all named fields, including + // captures, a null in a capture does not automatically make a + // null result. + return false; + } + + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx* context, + VectorPtr* result) const override { + VELOX_CHECK_EQ(args.size(), 2); + + // Flatten input map. + exec::LocalDecodedVector mapDecoder(context, *args[0], rows); + auto& decodedMap = *mapDecoder.get(); + + auto flatMap = flattenMap(rows, args[0], decodedMap); + + std::vector lambdaArgs = { + flatMap->mapKeys(), flatMap->mapValues()}; + auto numKeys = flatMap->mapKeys()->size(); + + VectorPtr transformedKeys; + + // Loop over lambda functions and apply these to keys of the map. + // In most cases there will be only one function and the loop will run once. + auto it = args[1]->asUnchecked()->iterator(&rows); + while (auto entry = it.next()) { + auto keyRows = + toElementRows(numKeys, *entry.rows, flatMap.get()); + auto wrapCapture = toWrapCapture( + numKeys, entry.callable, *entry.rows, flatMap); + + entry.callable->apply( + keyRows, wrapCapture, context, lambdaArgs, &transformedKeys); + } + + // TODO Check for duplicates in transformedKeys. + + auto localResult = std::make_shared( + flatMap->pool(), + outputType, + flatMap->nulls(), + flatMap->size(), + flatMap->offsets(), + flatMap->sizes(), + transformedKeys, + flatMap->mapValues()); + + checkDuplicateKeys(localResult, rows); + + context->moveOrCopyResult(localResult, rows, result); + } + + static std::vector> signatures() { + // map(K1, V), function(K1, V) -> K2 -> map(K2, V) + return {exec::FunctionSignatureBuilder() + .typeVariable("K1") + .typeVariable("K2") + .typeVariable("V") + .returnType("map(K2,V)") + .argumentType("map(K1,V)") + .argumentType("function(K1,V,K2)") + .build()}; + } + + private: + void checkDuplicateKeys( + const MapVectorPtr& mapVector, + const SelectivityVector& rows) const { + static const char* kDuplicateKey = "Duplicate map keys are not allowed"; + + MapVector::canonicalize(mapVector); + + auto offsets = mapVector->rawOffsets(); + auto sizes = mapVector->rawSizes(); + auto mapKeys = mapVector->mapKeys(); + rows.applyToSelected([&](auto row) { + auto offset = offsets[row]; + auto size = sizes[row]; + for (auto i = 1; i < size; i++) { + if (mapKeys->equalValueAt(mapKeys.get(), offset + i, offset + i - 1)) { + VELOX_USER_FAIL("{}", kDuplicateKey); + } + } + }); + } +}; +} // namespace + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_transform_keys, + TransformKeysFunction::signatures(), + std::make_unique()); + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/TransformValues.cpp b/velox/functions/prestosql/TransformValues.cpp new file mode 100644 index 000000000000..ded2df6bc115 --- /dev/null +++ b/velox/functions/prestosql/TransformValues.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/expression/Expr.h" +#include "velox/expression/VectorFunction.h" +#include "velox/functions/lib/LambdaFunctionUtil.h" +#include "velox/vector/FunctionVector.h" + +namespace facebook::velox::functions { +namespace { + +// See documentation at https://prestodb.io/docs/current/functions/map.html +class TransformValuesFunction : public exec::VectorFunction { + public: + bool isDefaultNullBehavior() const override { + // transform_values is null preserving for the map. But + // since an expr tree with a lambda depends on all named fields, including + // captures, a null in a capture does not automatically make a + // null result. + return false; + } + + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx* context, + VectorPtr* result) const override { + VELOX_CHECK_EQ(args.size(), 2); + + // Flatten input map. + exec::LocalDecodedVector mapDecoder(context, *args[0], rows); + auto& decodedMap = *mapDecoder.get(); + + auto flatMap = flattenMap(rows, args[0], decodedMap); + + std::vector lambdaArgs = { + flatMap->mapKeys(), flatMap->mapValues()}; + auto numValues = flatMap->mapValues()->size(); + + VectorPtr transformedValues; + + // Loop over lambda functions and apply these to keys of the map. + // In most cases there will be only one function and the loop will run once. + auto it = args[1]->asUnchecked()->iterator(&rows); + while (auto entry = it.next()) { + auto keyRows = + toElementRows(numValues, *entry.rows, flatMap.get()); + auto wrapCapture = toWrapCapture( + numValues, entry.callable, *entry.rows, flatMap); + + entry.callable->apply( + keyRows, wrapCapture, context, lambdaArgs, &transformedValues); + } + + auto localResult = std::make_shared( + flatMap->pool(), + outputType, + flatMap->nulls(), + flatMap->size(), + flatMap->offsets(), + flatMap->sizes(), + flatMap->mapKeys(), + transformedValues); + context->moveOrCopyResult(localResult, rows, result); + } + + static std::vector> signatures() { + // map(K, V1), function(K, V1) -> V2 -> map(K, V2) + return {exec::FunctionSignatureBuilder() + .typeVariable("K") + .typeVariable("V1") + .typeVariable("V2") + .returnType("map(K,V2)") + .argumentType("map(K,V1)") + .argumentType("function(K,V1,V2)") + .build()}; + } +}; +} // namespace + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_transform_values, + TransformValuesFunction::signatures(), + std::make_unique()); + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp index d25cc291cf70..104e28956efb 100644 --- a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp @@ -20,6 +20,8 @@ namespace facebook::velox::functions { void registerMapFunctions() { VELOX_REGISTER_VECTOR_FUNCTION(udf_map_filter, "map_filter"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_transform_keys, "transform_keys"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_transform_values, "transform_values"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map, "map"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_concat, "map_concat"); VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, "map_entries"); diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index 5517d9336aec..440affb3f0d1 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -59,6 +59,8 @@ add_executable( SplitTest.cpp StringFunctionsTest.cpp TransformTest.cpp + TransformKeysTest.cpp + TransformValuesTest.cpp URLFunctionsTest.cpp WidthBucketArrayTest.cpp GreatestLeastTest.cpp diff --git a/velox/functions/prestosql/tests/TransformKeysTest.cpp b/velox/functions/prestosql/tests/TransformKeysTest.cpp new file mode 100644 index 000000000000..7eeb991d4b87 --- /dev/null +++ b/velox/functions/prestosql/tests/TransformKeysTest.cpp @@ -0,0 +1,241 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/tests/FunctionBaseTest.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; + +class TransformKeysTest : public functions::test::FunctionBaseTest {}; + +TEST_F(TransformKeysTest, basic) { + vector_size_t size = 1'000; + auto input = makeRowVector({ + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)), + }); + registerLambda( + "plus5", + rowType("x", BIGINT(), "unused", INTEGER()), + input->type(), + "x + 5"); + + auto result = + evaluate("transform_keys(c0, function('plus5'))", input); + + auto expectedResult = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7 + 5; }, + [](auto row) { return row % 11; }, + nullEvery(13)); + assertEqualVectors(expectedResult, result); + + registerLambda( + "key+value", + rowType("k", BIGINT(), "v", INTEGER()), + input->type(), + "k + v"); + + result = + evaluate("transform_keys(c0, function('key+value'))", input); + + expectedResult = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7 + row % 11; }, + [](auto row) { return row % 11; }, + nullEvery(13)); + assertEqualVectors(expectedResult, result); +} + +TEST_F(TransformKeysTest, duplicateKeys) { + vector_size_t size = 1'000; + auto input = makeRowVector({ + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)), + }); + registerLambda( + "mod2", + rowType("x", BIGINT(), "unused", INTEGER()), + input->type(), + "x % 2"); + + VELOX_ASSERT_THROW( + evaluate("transform_keys(c0, function('mod2'))", input), + "Duplicate map keys are not allowed"); +} + +TEST_F(TransformKeysTest, differentResultType) { + vector_size_t size = 1'000; + auto input = makeRowVector({ + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)), + }); + registerLambda( + "oneTenth", + rowType("x", BIGINT(), "unused", INTEGER()), + input->type(), + "x::double * 0.1"); + + auto result = + evaluate("transform_keys(c0, function('oneTenth'))", input); + + auto expectedResult = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return (row % 7) * 0.1; }, + [](auto row) { return row % 11; }, + nullEvery(13)); + assertEqualVectors(expectedResult, result); +} + +// Test different lambdas applied to different rows. +TEST_F(TransformKeysTest, conditional) { + vector_size_t size = 1'000; + + // Make 2 columns: the map to transform and a boolean that decided which + // lambda to use. + auto inputMap = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)); + auto condition = + makeFlatVector(size, [](auto row) { return row % 3 == 1; }); + auto input = makeRowVector({condition, inputMap}); + auto signature = rowType("x", BIGINT(), "unused", INTEGER()); + registerLambda("plus5", signature, input->type(), "x + 5"); + registerLambda("minus3", signature, input->type(), "x - 3"); + + auto result = evaluate( + "transform_keys(c1, if (c0, function('plus5'), function('minus3')))", + input); + + // Make 2 expected vectors: one for rows where condition is true and another + // for rows where condition is false. + auto expectedPlus5 = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7 + 5; }, + [](auto row) { return row % 11; }, + nullEvery(13)); + auto expectedMinus3 = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7 - 3; }, + [](auto row) { return row % 11; }, + nullEvery(13)); + ASSERT_EQ(size, result->size()); + for (auto i = 0; i < size; i++) { + if (i % 3 == 1) { + ASSERT_TRUE(expectedPlus5->equalValueAt(result.get(), i, i)) + << "at " << i << ": " << expectedPlus5->toString(i) << " vs. " + << result->toString(i); + + } else { + ASSERT_TRUE(expectedMinus3->equalValueAt(result.get(), i, i)) + << "at " << i << ": " << expectedMinus3->toString(i) << " vs. " + << result->toString(i); + } + } +} + +TEST_F(TransformKeysTest, dictionaryWithUniqueValues) { + vector_size_t size = 1'000; + + auto indices = makeIndicesInReverse(size); + auto input = makeRowVector( + {makeFlatVector(size, [](auto /* row */) { return 5; }), + wrapInDictionary( + indices, + size, + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)))}); + + registerLambda( + "plus5", + rowType("x", BIGINT(), "unused", INTEGER()), + input->type(), + "x + c0"); + + auto result = + evaluate("transform_keys(c1, function('plus5'))", input); + + auto expectedResult = wrapInDictionary( + indices, + size, + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7 + 5; }, + [](auto row) { return row % 11; }, + nullEvery(13))); + assertEqualVectors(expectedResult, result); +} + +TEST_F(TransformKeysTest, dictionaryWithDuplicates) { + vector_size_t size = 1'000; + + // Make a map vector where each row repeats twice. + BufferPtr indices = makeIndices(size, [](auto row) { return row / 2; }); + auto inputMap = wrapInDictionary( + indices, + size, + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13))); + + // Make a capture with unique values. + auto capture = makeFlatVector(size, [](auto row) { return row; }); + + auto input = makeRowVector({capture, inputMap}); + + registerLambda( + "x+c0", + rowType("x", BIGINT(), "unused", INTEGER()), + input->type(), + "x + c0"); + + auto result = + evaluate("transform_keys(c1, function('x+c0'))", input); + + auto expectedResult = evaluate( + "transform_keys(c1, function('x+c0'))", + makeRowVector({capture, flatten(inputMap)})); + + assertEqualVectors(expectedResult, result); +} diff --git a/velox/functions/prestosql/tests/TransformValuesTest.cpp b/velox/functions/prestosql/tests/TransformValuesTest.cpp new file mode 100644 index 000000000000..ede226797565 --- /dev/null +++ b/velox/functions/prestosql/tests/TransformValuesTest.cpp @@ -0,0 +1,219 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/prestosql/tests/FunctionBaseTest.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; + +class TransformValuesTest : public functions::test::FunctionBaseTest {}; + +TEST_F(TransformValuesTest, basic) { + vector_size_t size = 1'000; + auto input = makeRowVector({ + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)), + }); + registerLambda( + "plus5", + rowType("unused_k", INTEGER(), "v", BIGINT()), + input->type(), + "v + 5"); + + auto result = + evaluate("transform_values(c0, function('plus5'))", input); + + auto expectedResult = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11 + 5; }, + nullEvery(13)); + assertEqualVectors(expectedResult, result); + + registerLambda( + "key+value", + rowType("k", INTEGER(), "v", BIGINT()), + input->type(), + "k + v"); + + result = + evaluate("transform_values(c0, function('key+value'))", input); + + expectedResult = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 7 + row % 11; }, + nullEvery(13)); + assertEqualVectors(expectedResult, result); +} + +TEST_F(TransformValuesTest, differentResultType) { + vector_size_t size = 1'000; + auto input = makeRowVector({ + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)), + }); + registerLambda( + "gt3", + rowType("unused_k", INTEGER(), "v", BIGINT()), + input->type(), + "v > 3"); + + auto result = + evaluate("transform_values(c0, function('gt3'))", input); + + auto expectedResult = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11 > 3; }, + nullEvery(13)); + assertEqualVectors(expectedResult, result); +} + +// Test different lambdas applied to different rows. +TEST_F(TransformValuesTest, conditional) { + vector_size_t size = 1'000; + + // Make 2 columns: the map to transform and a boolean that decided which + // lambda to use. + auto inputMap = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)); + auto condition = + makeFlatVector(size, [](auto row) { return row % 3 == 1; }); + auto input = makeRowVector({condition, inputMap}); + auto signature = rowType("unused_k", INTEGER(), "v", BIGINT()); + registerLambda("plus5", signature, input->type(), "v + 5"); + registerLambda("minus3", signature, input->type(), "v - 3"); + + auto result = evaluate( + "transform_values(c1, if (c0, function('plus5'), function('minus3')))", + input); + + // Make 2 expected vectors: one for rows where condition is true and another + // for rows where condition is false. + auto expectedPlus5 = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11 + 5; }, + nullEvery(13)); + auto expectedMinus3 = makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11 - 3; }, + nullEvery(13)); + ASSERT_EQ(size, result->size()); + for (auto i = 0; i < size; i++) { + if (i % 3 == 1) { + ASSERT_TRUE(expectedPlus5->equalValueAt(result.get(), i, i)) + << "at " << i << ": " << expectedPlus5->toString(i) << " vs. " + << result->toString(i); + + } else { + ASSERT_TRUE(expectedMinus3->equalValueAt(result.get(), i, i)) + << "at " << i << ": " << expectedMinus3->toString(i) << " vs. " + << result->toString(i); + } + } +} + +TEST_F(TransformValuesTest, dictionaryWithUniqueValues) { + vector_size_t size = 1'000; + + auto indices = makeIndicesInReverse(size); + auto input = makeRowVector( + {makeFlatVector(size, [](auto /* row */) { return 5; }), + wrapInDictionary( + indices, + size, + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13)))}); + + registerLambda( + "plus5", + rowType("unused_k", INTEGER(), "v", BIGINT()), + input->type(), + "v + c0"); + + auto result = + evaluate("transform_values(c1, function('plus5'))", input); + + auto expectedResult = wrapInDictionary( + indices, + size, + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11 + 5; }, + nullEvery(13))); + assertEqualVectors(expectedResult, result); +} + +TEST_F(TransformValuesTest, dictionaryWithDuplicates) { + vector_size_t size = 1'000; + + // Make a map vector where each row repeats twice. + BufferPtr indices = makeIndices(size, [](auto row) { return row / 2; }); + auto inputMap = wrapInDictionary( + indices, + size, + makeMapVector( + size, + [](auto row) { return row % 5; }, + [](auto row) { return row % 7; }, + [](auto row) { return row % 11; }, + nullEvery(13))); + + // Make a capture with unique values. + auto capture = makeFlatVector(size, [](auto row) { return row; }); + + auto input = makeRowVector({capture, inputMap}); + + registerLambda( + "v+c0", + rowType("unused_k", INTEGER(), "v", BIGINT()), + input->type(), + "v + c0"); + + auto result = + evaluate("transform_values(c1, function('v+c0'))", input); + + auto expectedResult = evaluate( + "transform_values(c1, function('v+c0'))", + makeRowVector({capture, flatten(inputMap)})); + + assertEqualVectors(expectedResult, result); +} diff --git a/velox/vector/tests/VectorTestBase.cpp b/velox/vector/tests/VectorTestBase.cpp index dff6f6102fbf..1a7b1613f125 100644 --- a/velox/vector/tests/VectorTestBase.cpp +++ b/velox/vector/tests/VectorTestBase.cpp @@ -86,7 +86,9 @@ void assertEqualVectors( const VectorPtr& actual, const std::string& additionalContext) { ASSERT_EQ(expected->size(), actual->size()) << additionalContext; - ASSERT_TRUE(expected->type()->equivalent(*actual->type())); + ASSERT_TRUE(expected->type()->equivalent(*actual->type())) + << "Expected " << expected->type()->toString() << ", but got " + << actual->type()->toString(); for (auto i = 0; i < expected->size(); i++) { ASSERT_TRUE(expected->equalValueAt(actual.get(), i, i)) << "at " << i << ": expected " << expected->toString(i) << ", but got "