diff --git a/be/src/exprs/string_functions.cpp b/be/src/exprs/string_functions.cpp index b02bf56837b1f7..998794c8392345 100644 --- a/be/src/exprs/string_functions.cpp +++ b/be/src/exprs/string_functions.cpp @@ -49,6 +49,15 @@ size_t get_utf8_byte_length(unsigned char byte) { } return char_size; } +size_t get_char_len(const StringVal& str, std::vector<size_t>* str_index) { // Counts the UTF-8 characters in str and appends each character's starting byte offset to str_index. + size_t char_len = 0; + for (size_t i = 0, char_size = 0; i < str.len; i += char_size) { + char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]); + str_index->push_back(i); + ++char_len; + } + return char_len; +} // This behaves identically to the mysql implementation, namely: // - 1-indexed positions @@ -73,8 +82,7 @@ StringVal StringFunctions::substring( std::vector<size_t> index; for (size_t i = 0, char_size = 0; i < str.len; i += char_size) { char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]); - index.push_back(byte_pos); - byte_pos += char_size; + index.push_back(i); if (pos.val > 0 && index.size() > pos.val + len.val) { break; } @@ -196,28 +204,45 @@ StringVal StringFunctions::lpad( if (str.is_null || len.is_null || pad.is_null || len.val < 0) { return StringVal::null(); } + + std::vector<size_t> str_index; + size_t str_char_size = get_char_len(str, &str_index); + std::vector<size_t> pad_index; + size_t pad_char_size = get_char_len(pad, &pad_index); + // Corner cases: Shrink the original string, or leave it alone. // TODO: Hive seems to go into an infinite loop if pad.len == 0, // so we should pay attention to Hive's future solution to be compatible. 
- if (len.val <= str.len || pad.len == 0) { - return StringVal(str.ptr, len.val); + if (len.val <= str_char_size || pad.len == 0) { + if (len.val > str_index.size()) { + return StringVal::null(); + } + if (len.val == str_index.size()) { + return StringVal(str.ptr, str.len); // len.val counts characters; the untruncated string is str.len bytes + } + return StringVal(str.ptr, str_index[len.val]); } // TODO pengyubing // StringVal result = StringVal::create_temp_string_val(context, len.val); - StringVal result(context, len.val); + int32_t pad_byte_len = 0; + int32_t pad_times = (len.val - str_char_size) / pad_char_size; + int32_t pad_remainder = (len.val - str_char_size) % pad_char_size; + pad_byte_len = pad_times * pad.len; + pad_byte_len += pad_index[pad_remainder]; + int32_t byte_len = str.len + pad_byte_len; + StringVal result(context, byte_len); if (result.is_null) { return result; } - int padded_prefix_len = len.val - str.len; - int pad_index = 0; + int pad_idx = 0; int result_index = 0; uint8_t* ptr = result.ptr; // Prepend chars of pad. - while (result_index < padded_prefix_len) { - ptr[result_index++] = pad.ptr[pad_index++]; - pad_index = pad_index % pad.len; + while (result_index < pad_byte_len) { + ptr[result_index++] = pad.ptr[pad_idx++]; + pad_idx = pad_idx % pad.len; } // Append given string. @@ -231,16 +256,34 @@ StringVal StringFunctions::rpad( if (str.is_null || len.is_null || pad.is_null || len.val < 0) { return StringVal::null(); } + + std::vector<size_t> str_index; + size_t str_char_size = get_char_len(str, &str_index); + std::vector<size_t> pad_index; + size_t pad_char_size = get_char_len(pad, &pad_index); + // Corner cases: Shrink the original string, or leave it alone. // TODO: Hive seems to go into an infinite loop if pad->len == 0, // so we should pay attention to Hive's future solution to be compatible. 
- if (len.val <= str.len || pad.len == 0) { - return StringVal(str.ptr, len.val); + if (len.val <= str_char_size || pad.len == 0) { + if (len.val > str_index.size()) { + return StringVal::null(); + } + if (len.val == str_index.size()) { + return StringVal(str.ptr, str.len); // len.val counts characters; the untruncated string is str.len bytes + } + return StringVal(str.ptr, str_index[len.val]); } // TODO pengyubing // StringVal result = StringVal::create_temp_string_val(context, len.val); - StringVal result(context, len.val); + int32_t pad_byte_len = 0; + int32_t pad_times = (len.val - str_char_size) / pad_char_size; + int32_t pad_remainder = (len.val - str_char_size) % pad_char_size; + pad_byte_len = pad_times * pad.len; + pad_byte_len += pad_index[pad_remainder]; + int32_t byte_len = str.len + pad_byte_len; + StringVal result(context, byte_len); if (UNLIKELY(result.is_null)) { return result; } @@ -248,11 +291,11 @@ StringVal StringFunctions::rpad( // Append chars of pad until desired length uint8_t* ptr = result.ptr; - int pad_index = 0; + int pad_idx = 0; int result_len = str.len; - while (result_len < len.val) { - ptr[result_len++] = pad.ptr[pad_index++]; - pad_index = pad_index % pad.len; + while (result_len < byte_len) { + ptr[result_len++] = pad.ptr[pad_idx++]; + pad_idx = pad_idx % pad.len; } return result; } @@ -295,7 +338,6 @@ IntVal StringFunctions::char_utf8_length(FunctionContext* context, const StringV return IntVal::null(); } size_t char_len = 0; - std::vector<size_t> index; for (size_t i = 0, char_size = 0; i < str.len; i += char_size) { char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]); ++char_len; @@ -412,11 +454,24 @@ IntVal StringFunctions::instr( if (str.is_null || substr.is_null) { return IntVal::null(); } + if (substr.len == 0) { + return IntVal(1); + } StringValue str_sv = StringValue::from_string_val(str); StringValue substr_sv = StringValue::from_string_val(substr); StringSearch search(&substr_sv); // Hive returns positions starting from 1. 
- return IntVal(search.search(&str_sv) + 1); + int loc = search.search(&str_sv); + if (loc > 0) { + size_t char_len = 0; + for (size_t i = 0, char_size = 0; i < loc; i += char_size) { + char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]); + ++char_len; + } + loc = char_len; // replace the byte offset of the match with the number of characters before it + } + + return IntVal(loc + 1); } IntVal StringFunctions::locate( @@ -430,20 +485,34 @@ IntVal StringFunctions::locate_pos( if (str.is_null || substr.is_null || start_pos.is_null) { return IntVal::null(); } + if (substr.len == 0) { + if (str.len == 0 && start_pos.val > 1) { + return IntVal(0); + } + return IntVal(start_pos.val); + } // Hive returns 0 for *start_pos <= 0, // but throws an exception for *start_pos > str->len. // Since returning 0 seems to be Hive's error condition, return 0. - if (start_pos.val <= 0 || start_pos.val > str.len) { + std::vector<size_t> index; + size_t char_len = get_char_len(str, &index); + if (start_pos.val <= 0 || start_pos.val > str.len || start_pos.val > char_len) { return IntVal(0); } StringValue substr_sv = StringValue::from_string_val(substr); StringSearch search(&substr_sv); // Input start_pos.val starts from 1. StringValue adjusted_str( - reinterpret_cast<char*>(str.ptr) + start_pos.val - 1, str.len - start_pos.val + 1); + reinterpret_cast<char*>(str.ptr) + index[start_pos.val - 1], str.len - index[start_pos.val - 1]); int32_t match_pos = search.search(&adjusted_str); if (match_pos >= 0) { // Hive returns the position in the original string starting from 1. 
+ size_t char_len = 0; + for (size_t i = 0, char_size = 0; i < match_pos; i += char_size) { + char_size = get_utf8_byte_length((unsigned)(adjusted_str.ptr)[i]); + ++char_len; + } + match_pos = char_len; return IntVal(start_pos.val + match_pos); } else { return IntVal(0); diff --git a/be/test/exprs/string_functions_test.cpp b/be/test/exprs/string_functions_test.cpp index 2670f576eb7a60..7320759834946d 100644 --- a/be/test/exprs/string_functions_test.cpp +++ b/be/test/exprs/string_functions_test.cpp @@ -68,9 +68,6 @@ TEST_F(StringFunctionsTest, money_format_large_int) { ss << str; __int128 value; ss >> value; - - std::cout << "value: " << value << std::endl; - StringVal result = StringFunctions::money_format(context, doris_udf::LargeIntVal(value)); StringVal expected = AnyValUtil::from_string_temp(context, std::string("170,141,183,460,469,231,731,687,303,715,884,105,727.00")); ASSERT_EQ(expected, result); @@ -361,6 +358,92 @@ TEST_F(StringFunctionsTest, append_trailing_char_if_absent) { StringVal("a"), StringVal("abc"))); } +TEST_F(StringFunctionsTest, instr) { + doris_udf::FunctionContext* context = new doris_udf::FunctionContext(); + ASSERT_EQ(IntVal(4), StringFunctions::instr(context, StringVal("foobarbar"), StringVal("bar"))); + ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("foobar"), StringVal("xbar"))); + ASSERT_EQ(IntVal(2), StringFunctions::instr(context, StringVal("123456234"), StringVal("234"))); + ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("123456"), StringVal("567"))); + ASSERT_EQ(IntVal(2), StringFunctions::instr(context, StringVal("1.234"), StringVal(".234"))); + ASSERT_EQ(IntVal(1), StringFunctions::instr(context, StringVal("1.234"), StringVal(""))); + ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal(""), StringVal("123"))); + ASSERT_EQ(IntVal(1), StringFunctions::instr(context, StringVal(""), StringVal(""))); + ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好世界"), StringVal("世界"))); 
+ ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("你好世界"), StringVal("您好"))); + ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好abc"), StringVal("a"))); + ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好abc"), StringVal("abc"))); + ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal::null(), StringVal("2"))); + ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal(""), StringVal::null())); + ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal::null(), StringVal::null())); +} + +TEST_F(StringFunctionsTest, locate) { + doris_udf::FunctionContext* context = new doris_udf::FunctionContext(); + ASSERT_EQ(IntVal(4), StringFunctions::locate(context, StringVal("bar"), StringVal("foobarbar"))); + ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("xbar"), StringVal("foobar"))); + ASSERT_EQ(IntVal(2), StringFunctions::locate(context, StringVal("234"), StringVal("123456234"))); + ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("567"), StringVal("123456"))); + ASSERT_EQ(IntVal(2), StringFunctions::locate(context, StringVal(".234"), StringVal("1.234"))); + ASSERT_EQ(IntVal(1), StringFunctions::locate(context, StringVal(""), StringVal("1.234"))); + ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("123"), StringVal(""))); + ASSERT_EQ(IntVal(1), StringFunctions::locate(context, StringVal(""), StringVal(""))); + ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("世界"), StringVal("你好世界"))); + ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("您好"), StringVal("你好世界"))); + ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("a"), StringVal("你好abc"))); + ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("abc"), StringVal("你好abc"))); + ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, StringVal::null(), StringVal("2"))); + ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, 
StringVal(""), StringVal::null())); + ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, StringVal::null(), StringVal::null())); +} + +TEST_F(StringFunctionsTest, locate_pos) { + doris_udf::FunctionContext* context = new doris_udf::FunctionContext(); + ASSERT_EQ(IntVal(7), StringFunctions::locate_pos(context, StringVal("bar"), StringVal("foobarbar"), IntVal(5))); + ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("xbar"), StringVal("foobar"), IntVal(1))); + ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal(""), StringVal("foobar"), IntVal(2))); + ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("foobar"), StringVal(""), IntVal(1))); + ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal(""), StringVal(""), IntVal(2))); + ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("A"), StringVal("AAAAAA"), IntVal(0))); + ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(0))); + ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(1))); + ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(2))); + ASSERT_EQ(IntVal(5), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(3))); + ASSERT_EQ(IntVal(7), StringFunctions::locate_pos(context, StringVal("BaR"), StringVal("foobarBaR"), IntVal(5))); + ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal("2"), IntVal(1))); + ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal(""), StringVal::null(), IntVal(4))); + ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal::null(), IntVal(4))); + ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal::null(), IntVal(-1))); +} + +TEST_F(StringFunctionsTest, lpad) { + 
ASSERT_EQ(StringVal("???hi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("?"))); + ASSERT_EQ(StringVal("g8%7IgY%AHx7luNtf8Kh"), StringFunctions::lpad(ctx, StringVal("g8%7IgY%AHx7luNtf8Kh"), IntVal(20), StringVal(""))); + ASSERT_EQ(StringVal("h"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(1), StringVal("?"))); + ASSERT_EQ(StringVal("你"), StringFunctions::lpad(ctx, StringVal("你好"), IntVal(1), StringVal("?"))); + ASSERT_EQ(StringVal(""), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(0), StringVal("?"))); + ASSERT_EQ(StringVal::null(), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(-1), StringVal("?"))); + ASSERT_EQ(StringVal("h"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(1), StringVal(""))); + ASSERT_EQ(StringVal::null(), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal(""))); + ASSERT_EQ(StringVal("abahi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("ab"))); + ASSERT_EQ(StringVal("ababhi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(6), StringVal("ab"))); + ASSERT_EQ(StringVal("呵呵呵hi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("呵呵"))); + ASSERT_EQ(StringVal("hih呵呵"), StringFunctions::lpad(ctx, StringVal("呵呵"), IntVal(5), StringVal("hi"))); +} + +TEST_F(StringFunctionsTest, rpad) { + ASSERT_EQ(StringVal("hi???"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("?"))); + ASSERT_EQ(StringVal("g8%7IgY%AHx7luNtf8Kh"), StringFunctions::rpad(ctx, StringVal("g8%7IgY%AHx7luNtf8Kh"), IntVal(20), StringVal(""))); + ASSERT_EQ(StringVal("h"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(1), StringVal("?"))); + ASSERT_EQ(StringVal("你"), StringFunctions::rpad(ctx, StringVal("你好"), IntVal(1), StringVal("?"))); + ASSERT_EQ(StringVal(""), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(0), StringVal("?"))); + ASSERT_EQ(StringVal::null(), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(-1), StringVal("?"))); + ASSERT_EQ(StringVal("h"), 
StringFunctions::rpad(ctx, StringVal("hi"), IntVal(1), StringVal(""))); + ASSERT_EQ(StringVal::null(), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal(""))); + ASSERT_EQ(StringVal("hiaba"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("ab"))); + ASSERT_EQ(StringVal("hiabab"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(6), StringVal("ab"))); + ASSERT_EQ(StringVal("hi呵呵呵"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("呵呵"))); + ASSERT_EQ(StringVal("呵呵hih"), StringFunctions::rpad(ctx, StringVal("呵呵"), IntVal(5), StringVal("hi"))); +} } int main(int argc, char** argv) { diff --git a/docs/.vuepress/sidebar/en.js b/docs/.vuepress/sidebar/en.js index e7ade2c6edb4fb..0a653f8045bdf6 100644 --- a/docs/.vuepress/sidebar/en.js +++ b/docs/.vuepress/sidebar/en.js @@ -221,6 +221,7 @@ module.exports = [ "regexp_replace", "repeat", "right", + "rpad", "split_part", "starts_with", "strleft", diff --git a/docs/.vuepress/sidebar/zh-CN.js b/docs/.vuepress/sidebar/zh-CN.js index 880c3404170e18..d386b30c222f6f 100644 --- a/docs/.vuepress/sidebar/zh-CN.js +++ b/docs/.vuepress/sidebar/zh-CN.js @@ -233,6 +233,7 @@ module.exports = [ "regexp_replace", "repeat", "right", + "rpad", "split_part", "starts_with", "strleft", diff --git a/docs/en/sql-reference/sql-functions/string-functions/lpad.md b/docs/en/sql-reference/sql-functions/string-functions/lpad.md index 1751cffc453f2b..686748a1340a42 100644 --- a/docs/en/sql-reference/sql-functions/string-functions/lpad.md +++ b/docs/en/sql-reference/sql-functions/string-functions/lpad.md @@ -28,10 +28,10 @@ under the License. ## Description ### Syntax -'VARCHAR lpad (VARCHAR str., INT len, VARCHAR pad)' +'VARCHAR lpad (VARCHAR str, INT len, VARCHAR pad)' -Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to STR until the length of the string reaches len. 
If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length. +Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to STR until the length of the string reaches len. If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length. The len is the character length, not the byte size. ## example diff --git a/docs/en/sql-reference/sql-functions/string-functions/rpad.md b/docs/en/sql-reference/sql-functions/string-functions/rpad.md new file mode 100644 index 00000000000000..232916f47ce1aa --- /dev/null +++ b/docs/en/sql-reference/sql-functions/string-functions/rpad.md @@ -0,0 +1,54 @@ +--- +{ + "title": "rpad", + "language": "en" +} +--- + + + +# rpad +## Description +### Syntax + +'VARCHAR rpad (VARCHAR str, INT len, VARCHAR pad)' + + +Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to the right of STR until the length of the string reaches len. If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length. The len is the character length, not the byte size. + +## example + +``` +mysql> SELECT rpad("hi", 5, "xy"); ++---------------------+ +| rpad('hi', 5, 'xy') | ++---------------------+ +| hixyx | ++---------------------+ + +mysql> SELECT rpad("hi", 1, "xy"); ++---------------------+ +| rpad('hi', 1, 'xy') | ++---------------------+ +| h | ++---------------------+ +``` +## keyword +RPAD diff --git a/docs/zh-CN/sql-reference/sql-functions/string-functions/lpad.md b/docs/zh-CN/sql-reference/sql-functions/string-functions/lpad.md index 62ea2cd03136d3..1173c2c5c5c4cf 100644 --- a/docs/zh-CN/sql-reference/sql-functions/string-functions/lpad.md +++ b/docs/zh-CN/sql-reference/sql-functions/string-functions/lpad.md @@ -31,7 +31,7 @@ under the License. 
`VARCHAR lpad(VARCHAR str, INT len, VARCHAR pad)` -返回 str 中长度为 len(从首字母开始算起)的字符串。如果 len 大于 str 的长度,则在 str 的前面不断补充 pad 字符,直到该字符串的长度达到 len 为止。如果 len 小于 str 的长度,该函数相当于截断 str 字符串,只返回长度为 len 的字符串。 +返回 str 中长度为 len(从首字母开始算起)的字符串。如果 len 大于 str 的长度,则在 str 的前面不断补充 pad 字符,直到该字符串的长度达到 len 为止。如果 len 小于 str 的长度,该函数相当于截断 str 字符串,只返回长度为 len 的字符串。len 指的是字符长度而不是字节长度。 ## example diff --git a/docs/zh-CN/sql-reference/sql-functions/string-functions/rpad.md b/docs/zh-CN/sql-reference/sql-functions/string-functions/rpad.md new file mode 100644 index 00000000000000..28c3fdb961688d --- /dev/null +++ b/docs/zh-CN/sql-reference/sql-functions/string-functions/rpad.md @@ -0,0 +1,54 @@ +--- +{ + "title": "rpad", + "language": "zh-CN" +} +--- + + + +# rpad +## description +### Syntax + +`VARCHAR rpad(VARCHAR str, INT len, VARCHAR pad)` + + +返回 str 中长度为 len(从首字母开始算起)的字符串。如果 len 大于 str 的长度,则在 str 的后面不断补充 pad 字符,直到该字符串的长度达到 len 为止。如果 len 小于 str 的长度,该函数相当于截断 str 字符串,只返回长度为 len 的字符串。len 指的是字符长度而不是字节长度。 + +## example + +``` +mysql> SELECT rpad("hi", 5, "xy"); ++---------------------+ +| rpad('hi', 5, 'xy') | ++---------------------+ +| hixyx | ++---------------------+ + +mysql> SELECT rpad("hi", 1, "xy"); ++---------------------+ +| rpad('hi', 1, 'xy') | ++---------------------+ +| h | ++---------------------+ +``` +## keyword +RPAD