Skip to content

Commit

Permalink
Add implementation for REGEXP_REPLACE
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigojdebem authored and anthonylouisbsb committed Jul 13, 2021
1 parent 090e2cf commit baf2778
Show file tree
Hide file tree
Showing 9 changed files with 394 additions and 5 deletions.
2 changes: 2 additions & 0 deletions cpp/src/gandiva/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ set(SRC_FILES
literal_holder.cc
projector.cc
regex_util.cc
replace_holder.cc
selection_vector.cc
tree_expr_builder.cc
to_date_holder.cc
Expand Down Expand Up @@ -233,6 +234,7 @@ add_gandiva_test(internals-test
to_date_holder_test.cc
simple_arena_test.cc
like_holder_test.cc
replace_holder_test.cc
decimal_type_util_test.cc
random_generator_holder_test.cc
hash_utils_test.cc
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/gandiva/function_holder_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "gandiva/like_holder.h"
#include "gandiva/node.h"
#include "gandiva/random_generator_holder.h"
#include "gandiva/replace_holder.h"
#include "gandiva/to_date_holder.h"

namespace gandiva {
Expand Down Expand Up @@ -66,6 +67,7 @@ class FunctionHolderRegistry {
{"to_date", LAMBDA_MAKER(ToDateHolder)},
{"random", LAMBDA_MAKER(RandomGeneratorHolder)},
{"rand", LAMBDA_MAKER(RandomGeneratorHolder)},
{"regexp_replace", LAMBDA_MAKER(ReplaceHolder)},
};
return maker_map;
}
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction("rpad", {}, DataTypeVector{utf8(), int32()}, utf8(),
kResultNullIfNull, "rpad_utf8_int32", NativeFunction::kNeedsContext),

NativeFunction("regexp_replace", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
kResultNullIfNull, "gdv_fn_regexp_replace_utf8_utf8",
NativeFunction::kNeedsContext |
NativeFunction::kNeedsFunctionHolder |
NativeFunction::kCanReturnErrors),

NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8()}, utf8(),
kResultNullIfNull, "concatOperator_utf8_utf8",
NativeFunction::kNeedsContext),
Expand Down
28 changes: 28 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "gandiva/like_holder.h"
#include "gandiva/precompiled/types.h"
#include "gandiva/random_generator_holder.h"
#include "gandiva/replace_holder.h"
#include "gandiva/to_date_holder.h"

/// Stub functions that can be accessed from LLVM or the pre-compiled library.
Expand All @@ -58,6 +59,18 @@ bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len,
return (*holder)(std::string(data, data_len));
}

const char* gdv_fn_regexp_replace_utf8_utf8(
int64_t ptr, int64_t holder_ptr, const char* data, int32_t data_len,
const char* /*pattern*/, int32_t /*pattern_len*/, const char* replace_string,
int32_t replace_string_len, int32_t* out_length) {
gandiva::ExecutionContext* context = reinterpret_cast<gandiva::ExecutionContext*>(ptr);

gandiva::ReplaceHolder* holder = reinterpret_cast<gandiva::ReplaceHolder*>(holder_ptr);

return (*holder)(context, data, data_len, replace_string, replace_string_len,
out_length);
}

double gdv_fn_random(int64_t ptr) {
gandiva::RandomGeneratorHolder* holder =
reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr);
Expand Down Expand Up @@ -824,6 +837,21 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
types->i1_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_ilike_utf8_utf8));

// gdv_fn_regexp_replace_utf8_utf8
args = {types->i64_type(), // int64_t ptr
types->i64_type(), // int64_t holder_ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int data_len
types->i8_ptr_type(), // const char* pattern
types->i32_type(), // int pattern_len
types->i8_ptr_type(), // const char* replace_string
types->i32_type(), // int32_t replace_string_len
types->i32_ptr_type()}; // int32_t* out_length

engine->AddGlobalMappingForFunc(
"gdv_fn_regexp_replace_utf8_utf8", types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_regexp_replace_utf8_utf8));

// gdv_fn_to_date_utf8_utf8
args = {types->i64_type(), // int64_t execution_context
types->i64_type(), // int64_t holder_ptr
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/gandiva/precompiled/string_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ TEST(TestStringOps, TestCastBoolToVarchar) {
EXPECT_EQ(std::string(out_str, out_len), "false");
EXPECT_FALSE(ctx.has_error());

out_str = castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len);
castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len);
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Output buffer length can't be negative"));
ctx.Reset();
Expand Down Expand Up @@ -1400,13 +1400,13 @@ TEST(TestStringOps, TestReplace) {
EXPECT_EQ(std::string(out_str, out_len), "TestString");
EXPECT_FALSE(ctx.has_error());

out_str = replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5,
5, &out_len);
replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5, 5,
&out_len);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string"));
ctx.Reset();

out_str = replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "eeee", 4, "e", 1, "aaaa", 4, 14,
&out_len);
replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "eeee", 4, "e", 1, "aaaa", 4, 14,
&out_len);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string"));
ctx.Reset();
}
Expand Down
65 changes: 65 additions & 0 deletions cpp/src/gandiva/replace_holder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "gandiva/replace_holder.h"

#include "gandiva/node.h"
#include "gandiva/regex_util.h"

namespace gandiva {

static bool IsArrowStringLiteral(arrow::Type::type type) {
return type == arrow::Type::STRING || type == arrow::Type::BINARY;
}

Status ReplaceHolder::Make(const FunctionNode& node,
std::shared_ptr<ReplaceHolder>* holder) {
ARROW_RETURN_IF(node.children().size() != 3,
Status::Invalid("'replace' function requires three parameters"));

auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get());
ARROW_RETURN_IF(
literal == nullptr,
Status::Invalid("'replace' function requires a literal as the second parameter"));

auto literal_type = literal->return_type()->id();
ARROW_RETURN_IF(
!IsArrowStringLiteral(literal_type),
Status::Invalid(
"'replace' function requires a string literal as the second parameter"));

return Make(arrow::util::get<std::string>(literal->holder()), holder);
}

Status ReplaceHolder::Make(const std::string& sql_pattern,
std::shared_ptr<ReplaceHolder>* holder) {
auto lholder = std::shared_ptr<ReplaceHolder>(new ReplaceHolder(sql_pattern));
ARROW_RETURN_IF(!lholder->regex_.ok(),
Status::Invalid("Building RE2 pattern '", sql_pattern, "' failed"));

*holder = lholder;
return Status::OK();
}

void ReplaceHolder::return_error(ExecutionContext* context, std::string& data,
std::string& replace_string) {
std::string err_msg = "Error replacing '" + replace_string + "' on the given string '" +
data + "' for the given pattern: " + pattern_;
context->set_error_msg(err_msg.c_str());
}

} // namespace gandiva
97 changes: 97 additions & 0 deletions cpp/src/gandiva/replace_holder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <re2/re2.h>

#include <memory>
#include <string>

#include "arrow/status.h"
#include "gandiva/execution_context.h"
#include "gandiva/function_holder.h"
#include "gandiva/node.h"
#include "gandiva/visibility.h"

namespace gandiva {

/// Function Holder for 'replace'
class GANDIVA_EXPORT ReplaceHolder : public FunctionHolder {
public:
~ReplaceHolder() override = default;

static Status Make(const FunctionNode& node, std::shared_ptr<ReplaceHolder>* holder);

static Status Make(const std::string& sql_pattern,
std::shared_ptr<ReplaceHolder>* holder);

/// Return a new string with the pattern that matched the regex replaced for
/// the replace_input parameter.
const char* operator()(ExecutionContext* ctx, const char* user_input,
int32_t user_input_len, const char* replace_input,
int32_t replace_input_len, int32_t* out_length) {
std::string user_input_as_str(user_input, user_input_len);
std::string replace_input_as_str(replace_input, replace_input_len);

int32_t total_replaces =
RE2::GlobalReplace(&user_input_as_str, regex_, replace_input_as_str);

if (total_replaces < 0) {
return_error(ctx, user_input_as_str, replace_input_as_str);
*out_length = 0;
return "";
}

if (total_replaces == 0) {
*out_length = user_input_len;
return user_input;
}

*out_length = static_cast<int32_t>(user_input_as_str.size());

// This condition treats the case where the whole string is replaced by an empty
// string
if (*out_length == 0) {
return "";
}

char* result_buffer = reinterpret_cast<char*>(ctx->arena()->Allocate(*out_length));

if (result_buffer == NULLPTR) {
ctx->set_error_msg("Could not allocate memory for result");
*out_length = 0;
return "";
}

memcpy(result_buffer, user_input_as_str.data(), *out_length);

return result_buffer;
}

private:
explicit ReplaceHolder(const std::string& pattern)
: pattern_(pattern), regex_(pattern) {}

void return_error(ExecutionContext* context, std::string& data,
std::string& replace_string);

std::string pattern_; // posix pattern string, to help debugging
RE2 regex_; // compiled regex for the pattern
};

} // namespace gandiva
Loading

0 comments on commit baf2778

Please sign in to comment.