Skip to content

Commit

Permalink
This is an automated cherry-pick of #9507
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <ti-community-prow-bot@tidb.io>
  • Loading branch information
gengliqi authored and ti-chi-bot committed Oct 10, 2024
1 parent 50bd2b0 commit 23398ad
Show file tree
Hide file tree
Showing 6 changed files with 566 additions and 71 deletions.
295 changes: 265 additions & 30 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,7 @@ class FunctionSubstringUTF8 : public IFunction

bool implicit_length = (arguments.size() == 2);

<<<<<<< HEAD
bool is_start_type_valid = getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
using StartType = std::decay_t<decltype(start_type)>;
// Int64 / UInt64
Expand All @@ -1692,6 +1693,48 @@ class FunctionSubstringUTF8 : public IFunction
length = getValueFromLengthField<LengthFieldType>((*block.getByPosition(arguments[2]).column)[0]);
return true;
});
=======
bool is_start_type_valid
= getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
using StartType = std::decay_t<decltype(start_type)>;
using StartFieldType = typename StartType::FieldType;
const ColumnVector<StartFieldType> * column_vector_start
= getInnerColumnVector<StartFieldType>(column_start);
if unlikely (!column_vector_start)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

// vector const const
if (!column_string->isColumnConst() && column_start->isColumnConst()
&& (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst()))
{
auto [is_positive, start_abs] = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
UInt64 length = 0;
if (!implicit_length)
{
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(block.getByPosition(arguments[2]).column);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

length = getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
return true;
});
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))

if (!is_length_type_valid)
throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
Expand All @@ -1704,6 +1747,7 @@ class FunctionSubstringUTF8 : public IFunction
return true;
}

<<<<<<< HEAD
const auto * col = checkAndGetColumn<ColumnString>(column_string.get());
assert(col);
auto col_res = ColumnString::create();
Expand Down Expand Up @@ -1757,6 +1801,80 @@ class FunctionSubstringUTF8 : public IFunction
if (!is_length_type_valid)
throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}
=======
const auto * col = checkAndGetColumn<ColumnString>(column_string.get());
assert(col);
auto col_res = ColumnString::create();
getVectorConstConstFunc(implicit_length, is_positive)(
col->getChars(),
col->getOffsets(),
start_abs,
length,
col_res->getChars(),
col_res->getOffsets());
block.getByPosition(result).column = std::move(col_res);
}
else // all other cases are converted to vector vector vector
{
std::function<std::pair<bool, size_t>(size_t)> get_start_func;
if (column_start->isColumnConst())
{
// func always return const value
auto start_const = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
get_start_func = [start_const](size_t) {
return start_const;
};
}
else
{
get_start_func = [column_vector_start](size_t i) {
return getValueFromStartColumn<StartFieldType>(*column_vector_start, i);
};
}

// if implicit_length, get_length_func be nil is ok.
std::function<size_t(size_t)> get_length_func;
if (!implicit_length)
{
const ColumnPtr & column_length = block.getByPosition(arguments[2]).column;
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (column_length->isColumnConst())
{
// func always return const value
auto length_const
= getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
get_length_func = [length_const](size_t) {
return length_const;
};
}
else
{
get_length_func = [column_vector_length](size_t i) {
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
};
}
return true;
});

if unlikely (!is_length_type_valid)
throw Exception(
fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))

// convert to vector if string is const.
ColumnPtr full_column_string = column_string->isColumnConst() ? column_string->convertToFullColumnIfConst() : column_string;
Expand All @@ -1777,10 +1895,38 @@ class FunctionSubstringUTF8 : public IFunction
return true;
});

if (!is_start_type_valid)
if unlikely (!is_start_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
}

template <typename Integer>
static const ColumnVector<Integer> * getInnerColumnVector(const ColumnPtr & column)
{
if (column->isColumnConst())
return checkAndGetColumn<ColumnVector<Integer>>(
checkAndGetColumn<ColumnConst>(column.get())->getDataColumnPtr().get());
return checkAndGetColumn<ColumnVector<Integer>>(column.get());
}

template <typename Integer>
static size_t getValueFromLengthColumn(const ColumnVector<Integer> & column, size_t index)
{
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
return val < 0 ? 0 : val;
}
else
{
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return val;
}
}

private:
using VectorConstConstFunc = std::function<void(
const ColumnString::Chars_t &,
Expand All @@ -1802,51 +1948,45 @@ class FunctionSubstringUTF8 : public IFunction
}
}

template <typename Integer>
static size_t getValueFromLengthField(const Field & length_field)
{
if constexpr (std::is_same_v<Integer, Int64>)
{
Int64 signed_length = length_field.get<Int64>();
return signed_length < 0 ? 0 : signed_length;
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
}
}

// return {is_positive, abs}
template <typename Integer>
static std::pair<bool, size_t> getValueFromStartField(const Field & start_field)
static std::pair<bool, size_t> getValueFromStartColumn(const ColumnVector<Integer> & column, size_t index)
{
if constexpr (std::is_same_v<Integer, Int64>)
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
Int64 signed_length = start_field.get<Int64>();

if (signed_length < 0)
{
return {false, static_cast<size_t>(-signed_length)};
}
else
{
return {true, static_cast<size_t>(signed_length)};
}
if (val < 0)
return {false, static_cast<size_t>(-val)};
return {true, static_cast<size_t>(val)};
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return {true, start_field.get<UInt64>()};
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return {true, val};
}
}

template <typename F>
static bool getNumberType(DataTypePtr type, F && f)
{
return castTypeToEither<
<<<<<<< HEAD
DataTypeInt64,
DataTypeUInt64>(type.get(), std::forward<F>(f));
=======
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))
}
};

Expand Down Expand Up @@ -1891,6 +2031,7 @@ class FunctionRightUTF8 : public IFunction
const ColumnPtr column_string = block.getByPosition(arguments[0]).column;
const ColumnPtr column_length = block.getByPosition(arguments[1]).column;

<<<<<<< HEAD
bool is_length_type_valid = getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
Expand All @@ -1903,6 +2044,33 @@ class FunctionRightUTF8 : public IFunction
{
// vector const
size_t length = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
=======
bool is_length_type_valid
= getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;

const ColumnVector<LengthFieldType> * column_vector_length
= FunctionSubstringUTF8::getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);


auto col_res = ColumnString::create();
if (const auto * col_string = checkAndGetColumn<ColumnString>(column_string.get()))
{
if (column_length->isColumnConst())
{
// vector const
size_t length = FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
0);
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))

// for const 0, return const blank string.
if (0 == length)
Expand All @@ -1911,6 +2079,7 @@ class FunctionRightUTF8 : public IFunction
return true;
}

<<<<<<< HEAD
RightUTF8Impl::vectorConst(col_string->getChars(), col_string->getOffsets(), length, col_res->getChars(), col_res->getOffsets());
}
else
Expand Down Expand Up @@ -1942,6 +2111,61 @@ class FunctionRightUTF8 : public IFunction
block.getByPosition(result).column = std::move(col_res);
return true;
});
=======
RightUTF8Impl::vectorConst(
col_string->getChars(),
col_string->getOffsets(),
length,
col_res->getChars(),
col_res->getOffsets());
}
else
{
// vector vector
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::vectorVector(
col_string->getChars(),
col_string->getOffsets(),
get_length_func,
col_res->getChars(),
col_res->getOffsets());
}
}
else if (
const ColumnConst * col_const_string = checkAndGetColumnConst<ColumnString>(column_string.get()))
{
// const vector
const auto * col_string_from_const
= checkAndGetColumn<ColumnString>(col_const_string->getDataColumnPtr().get());
assert(col_string_from_const);
// When useDefaultImplementationForConstants is true, string and length are not both constants
assert(!column_length->isColumnConst());
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::constVector(
column_length->size(),
col_string_from_const->getChars(),
col_string_from_const->getOffsets(),
get_length_func,
col_res->getChars(),
col_res->getOffsets());
}
else
{
// Impossible to reach here
return false;
}
block.getByPosition(result).column = std::move(col_res);
return true;
});
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))

if (!is_length_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
Expand All @@ -1953,6 +2177,7 @@ class FunctionRightUTF8 : public IFunction
getLengthType(DataTypePtr type, F && f)
{
return castTypeToEither<
<<<<<<< HEAD
DataTypeInt64,
DataTypeUInt64>(type.get(), std::forward<F>(f));
}
Expand All @@ -1970,6 +2195,16 @@ class FunctionRightUTF8 : public IFunction
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
}
=======
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))
}
};

Expand Down
Loading

0 comments on commit 23398ad

Please sign in to comment.