Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support utf-8 encoding in instr, locate, locate_pos, lpad, rpad #3638

Merged
merged 4 commits into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 90 additions & 21 deletions be/src/exprs/string_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ size_t get_utf8_byte_length(unsigned char byte) {
}
return char_size;
}
size_t get_char_len(const StringVal& str, std::vector<size_t>* str_index) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
yangzhg marked this conversation as resolved.
Show resolved Hide resolved
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
str_index->push_back(i);
++char_len;
}
return char_len;
}

// This behaves identically to the mysql implementation, namely:
// - 1-indexed positions
Expand All @@ -73,8 +82,7 @@ StringVal StringFunctions::substring(
std::vector<size_t> index;
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
index.push_back(byte_pos);
byte_pos += char_size;
index.push_back(i);
if (pos.val > 0 && index.size() > pos.val + len.val) {
break;
}
Expand Down Expand Up @@ -196,28 +204,45 @@ StringVal StringFunctions::lpad(
if (str.is_null || len.is_null || pad.is_null || len.val < 0) {
return StringVal::null();
}

std::vector<size_t> str_index;
size_t str_char_size = get_char_len(str, &str_index);
std::vector<size_t> pad_index;
size_t pad_char_size = get_char_len(pad, &pad_index);

// Corner cases: Shrink the original string, or leave it alone.
// TODO: Hive seems to go into an infinite loop if pad.len == 0,
// so we should pay attention to Hive's future solution to be compatible.
if (len.val <= str.len || pad.len == 0) {
return StringVal(str.ptr, len.val);
if (len.val <= str_char_size || pad.len == 0) {
if (len.val > str_index.size()) {
return StringVal::null();
}
if (len.val == str_index.size()) {
return StringVal(str.ptr, len.val);
}
return StringVal(str.ptr, str_index[len.val]);
}

// TODO pengyubing
// StringVal result = StringVal::create_temp_string_val(context, len.val);
StringVal result(context, len.val);
int32_t pad_byte_len = 0;
int32_t pad_times = (len.val - str_char_size) / pad_char_size;
int32_t pad_remainder = (len.val - str_char_size) % pad_char_size;
pad_byte_len = pad_times * pad.len;
pad_byte_len += pad_index[pad_remainder];
int32_t byte_len = str.len + pad_byte_len;
StringVal result(context, byte_len);
if (result.is_null) {
return result;
}
int padded_prefix_len = len.val - str.len;
int pad_index = 0;
int pad_idx = 0;
int result_index = 0;
uint8_t* ptr = result.ptr;

// Prepend chars of pad.
while (result_index < padded_prefix_len) {
ptr[result_index++] = pad.ptr[pad_index++];
pad_index = pad_index % pad.len;
while (result_index < pad_byte_len) {
ptr[result_index++] = pad.ptr[pad_idx++];
pad_idx = pad_idx % pad.len;
}

// Append given string.
Expand All @@ -231,28 +256,46 @@ StringVal StringFunctions::rpad(
if (str.is_null || len.is_null || pad.is_null || len.val < 0) {
return StringVal::null();
}

std::vector<size_t> str_index;
size_t str_char_size = get_char_len(str, &str_index);
std::vector<size_t> pad_index;
size_t pad_char_size = get_char_len(pad, &pad_index);

// Corner cases: Shrink the original string, or leave it alone.
// TODO: Hive seems to go into an infinite loop if pad->len == 0,
// so we should pay attention to Hive's future solution to be compatible.
if (len.val <= str.len || pad.len == 0) {
return StringVal(str.ptr, len.val);
if (len.val <= str_char_size || pad.len == 0) {
if (len.val > str_index.size()) {
return StringVal::null();
}
if (len.val == str_index.size()) {
return StringVal(str.ptr, len.val);
}
return StringVal(str.ptr, str_index[len.val]);
}

// TODO pengyubing
// StringVal result = StringVal::create_temp_string_val(context, len.val);
StringVal result(context, len.val);
int32_t pad_byte_len = 0;
int32_t pad_times = (len.val - str_char_size) / pad_char_size;
int32_t pad_remainder = (len.val - str_char_size) % pad_char_size;
pad_byte_len = pad_times * pad.len;
pad_byte_len += pad_index[pad_remainder];
int32_t byte_len = str.len + pad_byte_len;
StringVal result(context, byte_len);
if (UNLIKELY(result.is_null)) {
return result;
}
memcpy(result.ptr, str.ptr, str.len);

// Append chars of pad until desired length
uint8_t* ptr = result.ptr;
int pad_index = 0;
int pad_idx = 0;
int result_len = str.len;
while (result_len < len.val) {
ptr[result_len++] = pad.ptr[pad_index++];
pad_index = pad_index % pad.len;
while (result_len < byte_len) {
ptr[result_len++] = pad.ptr[pad_idx++];
pad_idx = pad_idx % pad.len;
}
return result;
}
Expand Down Expand Up @@ -295,7 +338,6 @@ IntVal StringFunctions::char_utf8_length(FunctionContext* context, const StringV
return IntVal::null();
}
size_t char_len = 0;
std::vector<size_t> index;
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
++char_len;
Expand Down Expand Up @@ -412,11 +454,24 @@ IntVal StringFunctions::instr(
if (str.is_null || substr.is_null) {
return IntVal::null();
}
if (substr.len == 0) {
return IntVal(1);
}
StringValue str_sv = StringValue::from_string_val(str);
StringValue substr_sv = StringValue::from_string_val(substr);
StringSearch search(&substr_sv);
// Hive returns positions starting from 1.
return IntVal(search.search(&str_sv) + 1);
int loc = search.search(&str_sv);
if (loc > 0) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < loc; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
++char_len;
}
loc = char_len;
}

return IntVal(loc + 1);
}

IntVal StringFunctions::locate(
Expand All @@ -430,20 +485,34 @@ IntVal StringFunctions::locate_pos(
if (str.is_null || substr.is_null || start_pos.is_null) {
return IntVal::null();
}
if (substr.len == 0) {
if (str.len == 0 && start_pos.val > 1) {
return IntVal(0);
}
return IntVal(start_pos.val);
}
// Hive returns 0 for *start_pos <= 0,
// but throws an exception for *start_pos > str->len.
// Since returning 0 seems to be Hive's error condition, return 0.
if (start_pos.val <= 0 || start_pos.val > str.len) {
std::vector<size_t> index;
size_t char_len = get_char_len(str, &index);
if (start_pos.val <= 0 || start_pos.val > str.len || start_pos.val > char_len) {
return IntVal(0);
}
StringValue substr_sv = StringValue::from_string_val(substr);
StringSearch search(&substr_sv);
// Input start_pos.val starts from 1.
StringValue adjusted_str(
reinterpret_cast<char*>(str.ptr) + start_pos.val - 1, str.len - start_pos.val + 1);
reinterpret_cast<char*>(str.ptr) + index[start_pos.val - 1], str.len - index[start_pos.val - 1]);
int32_t match_pos = search.search(&adjusted_str);
if (match_pos >= 0) {
// Hive returns the position in the original string starting from 1.
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < match_pos; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(adjusted_str.ptr)[i]);
++char_len;
}
match_pos = char_len;
return IntVal(start_pos.val + match_pos);
} else {
return IntVal(0);
Expand Down
89 changes: 86 additions & 3 deletions be/test/exprs/string_functions_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ TEST_F(StringFunctionsTest, money_format_large_int) {
ss << str;
__int128 value;
ss >> value;

std::cout << "value: " << value << std::endl;

StringVal result = StringFunctions::money_format(context, doris_udf::LargeIntVal(value));
StringVal expected = AnyValUtil::from_string_temp(context, std::string("170,141,183,460,469,231,731,687,303,715,884,105,727.00"));
ASSERT_EQ(expected, result);
Expand Down Expand Up @@ -361,6 +358,92 @@ TEST_F(StringFunctionsTest, append_trailing_char_if_absent) {
StringVal("a"), StringVal("abc")));
}

TEST_F(StringFunctionsTest, instr) {
doris_udf::FunctionContext* context = new doris_udf::FunctionContext();
ASSERT_EQ(IntVal(4), StringFunctions::instr(context, StringVal("foobarbar"), StringVal("bar")));
ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("foobar"), StringVal("xbar")));
ASSERT_EQ(IntVal(2), StringFunctions::instr(context, StringVal("123456234"), StringVal("234")));
ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("123456"), StringVal("567")));
ASSERT_EQ(IntVal(2), StringFunctions::instr(context, StringVal("1.234"), StringVal(".234")));
ASSERT_EQ(IntVal(1), StringFunctions::instr(context, StringVal("1.234"), StringVal("")));
ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal(""), StringVal("123")));
ASSERT_EQ(IntVal(1), StringFunctions::instr(context, StringVal(""), StringVal("")));
ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好世界"), StringVal("世界")));
ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("你好世界"), StringVal("您好")));
ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好abc"), StringVal("a")));
ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好abc"), StringVal("abc")));
ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal::null(), StringVal("2")));
ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal(""), StringVal::null()));
ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal::null(), StringVal::null()));
}

TEST_F(StringFunctionsTest, locate) {
doris_udf::FunctionContext* context = new doris_udf::FunctionContext();
ASSERT_EQ(IntVal(4), StringFunctions::locate(context, StringVal("bar"), StringVal("foobarbar")));
ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("xbar"), StringVal("foobar")));
ASSERT_EQ(IntVal(2), StringFunctions::locate(context, StringVal("234"), StringVal("123456234")));
ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("567"), StringVal("123456")));
ASSERT_EQ(IntVal(2), StringFunctions::locate(context, StringVal(".234"), StringVal("1.234")));
ASSERT_EQ(IntVal(1), StringFunctions::locate(context, StringVal(""), StringVal("1.234")));
ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("123"), StringVal("")));
ASSERT_EQ(IntVal(1), StringFunctions::locate(context, StringVal(""), StringVal("")));
ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("世界"), StringVal("你好世界")));
ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("您好"), StringVal("你好世界")));
ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("a"), StringVal("你好abc")));
ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("abc"), StringVal("你好abc")));
ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, StringVal::null(), StringVal("2")));
ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, StringVal(""), StringVal::null()));
ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, StringVal::null(), StringVal::null()));
}

TEST_F(StringFunctionsTest, locate_pos) {
doris_udf::FunctionContext* context = new doris_udf::FunctionContext();
ASSERT_EQ(IntVal(7), StringFunctions::locate_pos(context, StringVal("bar"), StringVal("foobarbar"), IntVal(5)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("xbar"), StringVal("foobar"), IntVal(1)));
ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal(""), StringVal("foobar"), IntVal(2)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("foobar"), StringVal(""), IntVal(1)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal(""), StringVal(""), IntVal(2)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("A"), StringVal("AAAAAA"), IntVal(0)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(0)));
ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(1)));
ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(2)));
ASSERT_EQ(IntVal(5), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(3)));
ASSERT_EQ(IntVal(7), StringFunctions::locate_pos(context, StringVal("BaR"), StringVal("foobarBaR"), IntVal(5)));
ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal("2"), IntVal(1)));
ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal(""), StringVal::null(), IntVal(4)));
ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal::null(), IntVal(4)));
ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal::null(), IntVal(-1)));
}

TEST_F(StringFunctionsTest, lpad) {
ASSERT_EQ(StringVal("???hi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("?")));
ASSERT_EQ(StringVal("g8%7IgY%AHx7luNtf8Kh"), StringFunctions::lpad(ctx, StringVal("g8%7IgY%AHx7luNtf8Kh"), IntVal(20), StringVal("")));
ASSERT_EQ(StringVal("h"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(1), StringVal("?")));
ASSERT_EQ(StringVal("你"), StringFunctions::lpad(ctx, StringVal("你好"), IntVal(1), StringVal("?")));
ASSERT_EQ(StringVal(""), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(0), StringVal("?")));
ASSERT_EQ(StringVal::null(), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(-1), StringVal("?")));
ASSERT_EQ(StringVal("h"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(1), StringVal("")));
ASSERT_EQ(StringVal::null(), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("")));
ASSERT_EQ(StringVal("abahi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("ab")));
ASSERT_EQ(StringVal("ababhi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(6), StringVal("ab")));
ASSERT_EQ(StringVal("呵呵呵hi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("呵呵")));
ASSERT_EQ(StringVal("hih呵呵"), StringFunctions::lpad(ctx, StringVal("呵呵"), IntVal(5), StringVal("hi")));
}

TEST_F(StringFunctionsTest, rpad) {
ASSERT_EQ(StringVal("hi???"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("?")));
ASSERT_EQ(StringVal("g8%7IgY%AHx7luNtf8Kh"), StringFunctions::rpad(ctx, StringVal("g8%7IgY%AHx7luNtf8Kh"), IntVal(20), StringVal("")));
ASSERT_EQ(StringVal("h"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(1), StringVal("?")));
ASSERT_EQ(StringVal("你"), StringFunctions::rpad(ctx, StringVal("你好"), IntVal(1), StringVal("?")));
ASSERT_EQ(StringVal(""), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(0), StringVal("?")));
ASSERT_EQ(StringVal::null(), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(-1), StringVal("?")));
ASSERT_EQ(StringVal("h"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(1), StringVal("")));
ASSERT_EQ(StringVal::null(), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("")));
ASSERT_EQ(StringVal("hiaba"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("ab")));
ASSERT_EQ(StringVal("hiabab"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(6), StringVal("ab")));
ASSERT_EQ(StringVal("hi呵呵呵"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("呵呵")));
ASSERT_EQ(StringVal("呵呵hih"), StringFunctions::rpad(ctx, StringVal("呵呵"), IntVal(5), StringVal("hi")));
}
}

int main(int argc, char** argv) {
Expand Down
1 change: 1 addition & 0 deletions docs/.vuepress/sidebar/en.js
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ module.exports = [
"regexp_replace",
"repeat",
"right",
"rpad",
"split_part",
"starts_with",
"strleft",
Expand Down
1 change: 1 addition & 0 deletions docs/.vuepress/sidebar/zh-CN.js
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ module.exports = [
"regexp_replace",
"repeat",
"right",
"rpad",
"split_part",
"starts_with",
"strleft",
Expand Down
4 changes: 2 additions & 2 deletions docs/en/sql-reference/sql-functions/string-functions/lpad.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ under the License.
## Description
### Syntax

'VARCHAR lpad (VARCHAR str., INT len, VARCHAR pad)'
'VARCHAR lpad (VARCHAR str, INT len, VARCHAR pad)'


Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to STR until the length of the string reaches len. If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length.
Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to STR until the length of the string reaches len. If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length. The len is character length not the bye size.

## example

Expand Down
54 changes: 54 additions & 0 deletions docs/en/sql-reference/sql-functions/string-functions/rpad.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
---
{
"title": "rpad",
"language": "en"
}
---

<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->

# rpad
## Description
### Syntax

'VARCHAR rpad (VARCHAR str, INT len, VARCHAR pad)'


Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to the right of STR until the length of the string reaches len. If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length. The len is character length not the bye size.

## example

```
mysql> SELECT rpad("hi", 5, "xy");
+---------------------+
| rpad('hi', 5, 'xy') |
+---------------------+
| hixyx |
+---------------------+

mysql> SELECT rpad("hi", 1, "xy");
+---------------------+
| rpad('hi', 1, 'xy') |
+---------------------+
| h |
+---------------------+
```
##keyword
RPAD
Loading