diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 83cec08e71c1d..2f0c46382fd9c 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -86,6 +86,7 @@ set(SRC_FILES literal_holder.cc projector.cc regex_util.cc + replace_holder.cc selection_vector.cc tree_expr_builder.cc to_date_holder.cc @@ -233,6 +234,7 @@ add_gandiva_test(internals-test to_date_holder_test.cc simple_arena_test.cc like_holder_test.cc + replace_holder_test.cc decimal_type_util_test.cc random_generator_holder_test.cc hash_utils_test.cc diff --git a/cpp/src/gandiva/function_holder_registry.h b/cpp/src/gandiva/function_holder_registry.h index 225c73207fcc0..ced1538915dd5 100644 --- a/cpp/src/gandiva/function_holder_registry.h +++ b/cpp/src/gandiva/function_holder_registry.h @@ -28,6 +28,7 @@ #include "gandiva/like_holder.h" #include "gandiva/node.h" #include "gandiva/random_generator_holder.h" +#include "gandiva/replace_holder.h" #include "gandiva/to_date_holder.h" namespace gandiva { @@ -66,6 +67,7 @@ class FunctionHolderRegistry { {"to_date", LAMBDA_MAKER(ToDateHolder)}, {"random", LAMBDA_MAKER(RandomGeneratorHolder)}, {"rand", LAMBDA_MAKER(RandomGeneratorHolder)}, + {"regexp_replace", LAMBDA_MAKER(ReplaceHolder)}, }; return maker_map; } diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 9235a3e01a258..b3c99840104c1 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -170,6 +170,12 @@ std::vector GetStringFunctionRegistry() { NativeFunction("rpad", {}, DataTypeVector{utf8(), int32()}, utf8(), kResultNullIfNull, "rpad_utf8_int32", NativeFunction::kNeedsContext), + NativeFunction("regexp_replace", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(), + kResultNullIfNull, "gdv_fn_regexp_replace_utf8_utf8", + NativeFunction::kNeedsContext | + NativeFunction::kNeedsFunctionHolder | + NativeFunction::kCanReturnErrors), + NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8()}, utf8(), kResultNullIfNull, "concatOperator_utf8_utf8", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 3c278049ed6fb..c98647bb90961 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -33,6 +33,7 @@ #include "gandiva/like_holder.h" #include "gandiva/precompiled/types.h" #include "gandiva/random_generator_holder.h" +#include "gandiva/replace_holder.h" #include "gandiva/to_date_holder.h" /// Stub functions that can be accessed from LLVM or the pre-compiled library. @@ -58,6 +59,18 @@ bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len, return (*holder)(std::string(data, data_len)); } +const char* gdv_fn_regexp_replace_utf8_utf8( + int64_t ptr, int64_t holder_ptr, const char* data, int32_t data_len, + const char* /*pattern*/, int32_t /*pattern_len*/, const char* replace_string, + int32_t replace_string_len, int32_t* out_length) { + gandiva::ExecutionContext* context = reinterpret_cast(ptr); + + gandiva::ReplaceHolder* holder = reinterpret_cast(holder_ptr); + + return (*holder)(context, data, data_len, replace_string, replace_string_len, + out_length); +} + double gdv_fn_random(int64_t ptr) { gandiva::RandomGeneratorHolder* holder = reinterpret_cast(ptr); @@ -824,6 +837,21 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { types->i1_type() /*return_type*/, args, reinterpret_cast(gdv_fn_ilike_utf8_utf8)); + // gdv_fn_regexp_replace_utf8_utf8 + args = {types->i64_type(), // int64_t ptr + types->i64_type(), // int64_t holder_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int data_len + types->i8_ptr_type(), // const char* pattern + types->i32_type(), // int pattern_len + types->i8_ptr_type(), // const char* replace_string + types->i32_type(), // int32_t replace_string_len + types->i32_ptr_type()}; // int32_t* out_length + + engine->AddGlobalMappingForFunc( + "gdv_fn_regexp_replace_utf8_utf8", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_regexp_replace_utf8_utf8)); + // gdv_fn_to_date_utf8_utf8 args = {types->i64_type(), // int64_t execution_context types->i64_type(), // int64_t holder_ptr diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index c4854c52db115..660f297a8e2b1 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -215,7 +215,7 @@ TEST(TestStringOps, TestCastBoolToVarchar) { EXPECT_EQ(std::string(out_str, out_len), "false"); EXPECT_FALSE(ctx.has_error()); - out_str = castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len); + castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Output buffer length can't be negative")); ctx.Reset(); @@ -1400,13 +1400,13 @@ TEST(TestStringOps, TestReplace) { EXPECT_EQ(std::string(out_str, out_len), "TestString"); EXPECT_FALSE(ctx.has_error()); - out_str = replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5, - 5, &out_len); + replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5, 5, + &out_len); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string")); ctx.Reset(); - out_str = replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "eeee", 4, "e", 1, "aaaa", 4, 14, - &out_len); + replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "eeee", 4, "e", 1, "aaaa", 4, 14, + &out_len); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string")); ctx.Reset(); } diff --git a/cpp/src/gandiva/replace_holder.cc b/cpp/src/gandiva/replace_holder.cc new file mode 100644 index 0000000000000..8b42b585f9ce2 --- /dev/null +++ b/cpp/src/gandiva/replace_holder.cc @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/replace_holder.h" + +#include "gandiva/node.h" +#include "gandiva/regex_util.h" + +namespace gandiva { + +static bool IsArrowStringLiteral(arrow::Type::type type) { + return type == arrow::Type::STRING || type == arrow::Type::BINARY; +} + +Status ReplaceHolder::Make(const FunctionNode& node, + std::shared_ptr* holder) { + ARROW_RETURN_IF(node.children().size() != 3, + Status::Invalid("'replace' function requires three parameters")); + + auto literal = dynamic_cast(node.children().at(1).get()); + ARROW_RETURN_IF( + literal == nullptr, + Status::Invalid("'replace' function requires a literal as the second parameter")); + + auto literal_type = literal->return_type()->id(); + ARROW_RETURN_IF( + !IsArrowStringLiteral(literal_type), + Status::Invalid( + "'replace' function requires a string literal as the second parameter")); + + return Make(arrow::util::get(literal->holder()), holder); +} + +Status ReplaceHolder::Make(const std::string& sql_pattern, + std::shared_ptr* holder) { + auto lholder = std::shared_ptr(new ReplaceHolder(sql_pattern)); + ARROW_RETURN_IF(!lholder->regex_.ok(), + Status::Invalid("Building RE2 pattern '", sql_pattern, "' failed")); + + *holder = lholder; + return Status::OK(); +} + +void ReplaceHolder::return_error(ExecutionContext* context, std::string& data, + std::string& replace_string) { + std::string err_msg = "Error replacing '" + replace_string + "' on the given string '" + + data + "' for the given pattern: " + pattern_; + context->set_error_msg(err_msg.c_str()); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/replace_holder.h b/cpp/src/gandiva/replace_holder.h new file mode 100644 index 0000000000000..79150d7aa4d57 --- /dev/null +++ b/cpp/src/gandiva/replace_holder.h @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include + +#include "arrow/status.h" +#include "gandiva/execution_context.h" +#include "gandiva/function_holder.h" +#include "gandiva/node.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// Function Holder for 'replace' +class GANDIVA_EXPORT ReplaceHolder : public FunctionHolder { + public: + ~ReplaceHolder() override = default; + + static Status Make(const FunctionNode& node, std::shared_ptr* holder); + + static Status Make(const std::string& sql_pattern, + std::shared_ptr* holder); + + /// Return a new string with the pattern that matched the regex replaced for + /// the replace_input parameter. + const char* operator()(ExecutionContext* ctx, const char* user_input, + int32_t user_input_len, const char* replace_input, + int32_t replace_input_len, int32_t* out_length) { + std::string user_input_as_str(user_input, user_input_len); + std::string replace_input_as_str(replace_input, replace_input_len); + + int32_t total_replaces = + RE2::GlobalReplace(&user_input_as_str, regex_, replace_input_as_str); + + if (total_replaces < 0) { + return_error(ctx, user_input_as_str, replace_input_as_str); + *out_length = 0; + return ""; + } + + if (total_replaces == 0) { + *out_length = user_input_len; + return user_input; + } + + *out_length = static_cast(user_input_as_str.size()); + + // This condition treats the case where the whole string is replaced by an empty + // string + if (*out_length == 0) { + return ""; + } + + char* result_buffer = reinterpret_cast(ctx->arena()->Allocate(*out_length)); + + if (result_buffer == NULLPTR) { + ctx->set_error_msg("Could not allocate memory for result"); + *out_length = 0; + return ""; + } + + memcpy(result_buffer, user_input_as_str.data(), *out_length); + + return result_buffer; + } + + private: + explicit ReplaceHolder(const std::string& pattern) + : pattern_(pattern), regex_(pattern) {} + + void return_error(ExecutionContext* context, std::string& data, + std::string& replace_string); + + std::string pattern_; // posix pattern string, to help debugging + RE2 regex_; // compiled regex for the pattern +}; + +} // namespace gandiva diff --git a/cpp/src/gandiva/replace_holder_test.cc b/cpp/src/gandiva/replace_holder_test.cc new file mode 100644 index 0000000000000..b0830d4f00465 --- /dev/null +++ b/cpp/src/gandiva/replace_holder_test.cc @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/replace_holder.h" + +#include + +#include +#include + +namespace gandiva { + +class TestReplaceHolder : public ::testing::Test { + protected: + ExecutionContext execution_context_; +}; + +TEST_F(TestReplaceHolder, TestMultipleReplace) { + std::shared_ptr replace_holder; + + auto status = ReplaceHolder::Make("ana", &replace_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + std::string input_string = "banana"; + std::string replace_string; + int32_t out_length = 0; + + auto& replace = *replace_holder; + const char* ret = + replace(&execution_context_, input_string.c_str(), + static_cast(input_string.length()), replace_string.c_str(), + static_cast(replace_string.length()), &out_length); + std::string ret_as_str(ret, out_length); + EXPECT_EQ(out_length, 3); + EXPECT_EQ(ret_as_str, "bna"); + + input_string = "bananaana"; + + ret = replace(&execution_context_, input_string.c_str(), + static_cast(input_string.length()), replace_string.c_str(), + static_cast(replace_string.length()), &out_length); + ret_as_str = std::string(ret, out_length); + EXPECT_EQ(out_length, 3); + EXPECT_EQ(ret_as_str, "bna"); + + input_string = "bananana"; + + ret = replace(&execution_context_, input_string.c_str(), + static_cast(input_string.length()), replace_string.c_str(), + static_cast(replace_string.length()), &out_length); + ret_as_str = std::string(ret, out_length); + EXPECT_EQ(out_length, 2); + EXPECT_EQ(ret_as_str, "bn"); + + input_string = "anaana"; + + ret = replace(&execution_context_, input_string.c_str(), + static_cast(input_string.length()), replace_string.c_str(), + static_cast(replace_string.length()), &out_length); + ret_as_str = std::string(ret, out_length); + EXPECT_EQ(out_length, 0); + EXPECT_FALSE(execution_context_.has_error()); + EXPECT_EQ(ret_as_str, ""); +} + +TEST_F(TestReplaceHolder, TestNoMatchPattern) { + std::shared_ptr replace_holder; + + auto status = ReplaceHolder::Make("ana", &replace_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + std::string input_string = "apple"; + std::string replace_string; + int32_t out_length = 0; + + auto& replace = *replace_holder; + const char* ret = + replace(&execution_context_, input_string.c_str(), + static_cast(input_string.length()), replace_string.c_str(), + static_cast(replace_string.length()), &out_length); + std::string ret_as_string(ret, out_length); + EXPECT_EQ(out_length, 5); + EXPECT_EQ(ret_as_string, "apple"); +} + +TEST_F(TestReplaceHolder, TestReplaceSameSize) { + std::shared_ptr replace_holder; + + auto status = ReplaceHolder::Make("a", &replace_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + std::string input_string = "ananindeua"; + std::string replace_string = "b"; + int32_t out_length = 0; + + auto& replace = *replace_holder; + const char* ret = + replace(&execution_context_, input_string.c_str(), + static_cast(input_string.length()), replace_string.c_str(), + static_cast(replace_string.length()), &out_length); + std::string ret_as_string(ret, out_length); + EXPECT_EQ(out_length, 10); + EXPECT_EQ(ret_as_string, "bnbnindeub"); +} + +TEST_F(TestReplaceHolder, TestReplaceInvalidPattern) { + std::shared_ptr replace_holder; + + auto status = ReplaceHolder::Make("+", &replace_holder); + EXPECT_EQ(status.ok(), false) << status.message(); + + execution_context_.Reset(); +} + +} // namespace gandiva diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java index 80d4281f4c2db..734ad87fe1e64 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java @@ -654,6 +654,66 @@ public void testRegex() throws GandivaException { eval.close(); } + @Test + public void testRegexpReplace() throws GandivaException { + + Field x = Field.nullable("x", new ArrowType.Utf8()); + Field replaceString = Field.nullable("replaceString", new ArrowType.Utf8()); + + Field retType = Field.nullable("c", new ArrowType.Utf8()); + + TreeNode cond = + TreeBuilder.makeFunction( + "regexp_replace", + Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeStringLiteral("ana"), + TreeBuilder.makeField(replaceString)), + new ArrowType.Utf8()); + ExpressionTree expr = TreeBuilder.makeExpression(cond, retType); + Schema schema = new Schema(Lists.newArrayList(x, replaceString)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 5; + byte[] validity = new byte[]{(byte) 15, 0}; + String[] valuesX = new String[]{"banana", "bananaana", "bananana", "anaana", "anaana"}; + String[] valuesReplace = new String[]{"ue", "", "", "c", ""}; + String[] expected = new String[]{"buena", "bna", "bn", "cc", null}; + + ArrowBuf validityX = buf(validity); + ArrowBuf validityReplace = buf(validity); + List dataBufsX = stringBufs(valuesX); + List dataBufsReplace = stringBufs(valuesReplace); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(validityX, dataBufsX.get(0), dataBufsX.get(1), validityReplace, + dataBufsReplace.get(0), dataBufsReplace.get(1))); + + // allocate data for output vector. + VarCharVector outVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator); + outVector.allocateNew(numRows * 15, numRows); + + // evaluate expression + List output = new ArrayList<>(); + output.add(outVector); + eval.evaluate(batch, output); + eval.close(); + + // match expected output. + for (int i = 0; i < numRows - 1; i++) { + assertFalse("Expect none value equals null", outVector.isNull(i)); + assertEquals(expected[i], new String(outVector.get(i))); + } + + assertTrue("Last value must be null", outVector.isNull(numRows - 1)); + + releaseRecordBatch(batch); + releaseValueVectors(output); + } + @Test public void testRand() throws GandivaException {