Skip to content

Commit

Permalink
This is an automated cherry-pick of #9615
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <ti-community-prow-bot@tidb.io>
  • Loading branch information
guo-shaoge authored and ti-chi-bot committed Nov 18, 2024
1 parent ba9fab8 commit ee676e1
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 61 deletions.
228 changes: 191 additions & 37 deletions dbms/src/Functions/FunctionsStringReplace.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/GatherUtils/Algorithms.h>
#include <Functions/GatherUtils/Sources.h>
#include <Functions/IFunction.h>

namespace DB
Expand All @@ -47,33 +49,17 @@ class FunctionStringReplace : public IFunction
return name;
}

<<<<<<< HEAD
size_t getNumberOfArguments() const override
{
return 0;
}
=======
size_t getNumberOfArguments() const override { return 3; }
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))

bool isVariadic() const override { return true; }
bool isVariadic() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
{
if constexpr (Impl::support_non_const_needle && Impl::support_non_const_replacement)
{
return {3, 4, 5};
}
else if constexpr (Impl::support_non_const_needle)
{
return {2, 3, 4, 5};
}
else if constexpr (Impl::support_non_const_replacement)
{
return {1, 3, 4, 5};
}
else
{
return {1, 2, 3, 4, 5};
}
}
void setCollator(const TiDB::TiDBCollatorPtr & collator_) override { collator = collator_; }

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
Expand All @@ -89,6 +75,7 @@ class FunctionStringReplace : public IFunction
throw Exception("Illegal type " + arguments[2]->getName() + " of third argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

<<<<<<< HEAD
if (arguments.size() > 3 && !arguments[3]->isInteger())
throw Exception("Illegal type " + arguments[2]->getName() + " of forth argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
Expand All @@ -101,11 +88,14 @@ class FunctionStringReplace : public IFunction
throw Exception("Illegal type " + arguments[2]->getName() + " of sixth argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

=======
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
return std::make_shared<DataTypeString>();
}

void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override
{
<<<<<<< HEAD
const ColumnPtr & column_src = block.getByPosition(arguments[0]).column;
const ColumnPtr & column_needle = block.getByPosition(arguments[1]).column;
const ColumnPtr & column_replacement = block.getByPosition(arguments[2]).column;
Expand All @@ -120,18 +110,34 @@ class FunctionStringReplace : public IFunction
Int64 pos = column_pos == nullptr ? 1 : typeid_cast<const ColumnConst *>(column_pos.get())->getInt(0);
Int64 occ = column_occ == nullptr ? 0 : typeid_cast<const ColumnConst *>(column_occ.get())->getInt(0);
String match_type = column_match_type == nullptr ? "" : typeid_cast<const ColumnConst *>(column_match_type.get())->getValue<String>();
=======
ColumnPtr column_src = block.getByPosition(arguments[0]).column;
ColumnPtr column_needle = block.getByPosition(arguments[1]).column;
ColumnPtr column_replacement = block.getByPosition(arguments[2]).column;
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))

ColumnWithTypeAndName & column_result = block.getByPosition(result);

bool needle_const = column_needle->isColumnConst();
bool replacement_const = column_replacement->isColumnConst();

if (needle_const && replacement_const)
if (column_src->isColumnConst())
{
executeImplConstHaystack(
column_src,
column_needle,
column_replacement,
needle_const,
replacement_const,
column_result);
}
else if (needle_const && replacement_const)
{
executeImpl(column_src, column_needle, column_replacement, pos, occ, match_type, column_result);
executeImpl(column_src, column_needle, column_replacement, column_result);
}
else if (needle_const)
{
<<<<<<< HEAD
executeImplNonConstReplacement(column_src, column_needle, column_replacement, pos, occ, match_type, column_result);
}
else if (replacement_const)
Expand All @@ -141,6 +147,17 @@ class FunctionStringReplace : public IFunction
else
{
executeImplNonConstNeedleReplacement(column_src, column_needle, column_replacement, pos, occ, match_type, column_result);
=======
executeImplNonConstReplacement(column_src, column_needle, column_replacement, column_result);
}
else if (replacement_const)
{
executeImplNonConstNeedle(column_src, column_needle, column_replacement, column_result);
}
else
{
executeImplNonConstNeedleReplacement(column_src, column_needle, column_replacement, column_result);
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
}
}

Expand All @@ -149,9 +166,6 @@ class FunctionStringReplace : public IFunction
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
Int64 pos,
Int64 occ,
const String & match_type,
ColumnWithTypeAndName & column_result) const
{
const auto * c1_const = typeid_cast<const ColumnConst *>(column_needle.get());
Expand All @@ -162,13 +176,33 @@ class FunctionStringReplace : public IFunction
if (const auto * col = checkAndGetColumn<ColumnString>(column_src.get()))
{
auto col_res = ColumnString::create();
<<<<<<< HEAD
Impl::vector(col->getChars(), col->getOffsets(), needle, replacement, pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets());
=======
Impl::vector(
col->getChars(),
col->getOffsets(),
needle,
replacement,
col_res->getChars(),
col_res->getOffsets());
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
column_result.column = std::move(col_res);
}
else if (const auto * col = checkAndGetColumn<ColumnFixedString>(column_src.get()))
{
auto col_res = ColumnString::create();
<<<<<<< HEAD
Impl::vectorFixed(col->getChars(), col->getN(), needle, replacement, pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets());
=======
Impl::vectorFixed(
col->getChars(),
col->getN(),
needle,
replacement,
col_res->getChars(),
col_res->getOffsets());
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
column_result.column = std::move(col_res);
}
else
Expand All @@ -177,13 +211,73 @@ class FunctionStringReplace : public IFunction
ErrorCodes::ILLEGAL_COLUMN);
}

void executeImplConstHaystack(
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
bool needle_const,
bool replacement_const,
ColumnWithTypeAndName & column_result) const
{
auto res_col = ColumnString::create();
res_col->reserve(column_src->size());

RUNTIME_CHECK_MSG(
!needle_const || !replacement_const,
"should not got here when all argments of replace are constant");

const auto * column_src_const = checkAndGetColumnConst<ColumnString>(column_src.get());
RUNTIME_CHECK(column_src_const);

using GatherUtils::ConstSource;
using GatherUtils::StringSource;
if (!needle_const && !replacement_const)
{
const auto * column_needle_string = checkAndGetColumn<ColumnString>(column_needle.get());
const auto * column_replacement_string = checkAndGetColumn<ColumnString>(column_replacement.get());
RUNTIME_CHECK(column_needle_string);
RUNTIME_CHECK(column_replacement_string);

GatherUtils::replace<Impl>(
ConstSource<StringSource>(*column_src_const),
StringSource(*column_needle_string),
StringSource(*column_replacement_string),
res_col);
}
else if (needle_const && !replacement_const)
{
const auto * column_needle_const = checkAndGetColumnConst<ColumnString>(column_needle.get());
const auto * column_replacement_string = checkAndGetColumn<ColumnString>(column_replacement.get());
RUNTIME_CHECK(column_needle_const);
RUNTIME_CHECK(column_replacement_string);

GatherUtils::replace<Impl>(
ConstSource<StringSource>(*column_src_const),
ConstSource<StringSource>(*column_needle_const),
StringSource(*column_replacement_string),
res_col);
}
else if (!needle_const && replacement_const)
{
const auto * column_needle_string = checkAndGetColumn<ColumnString>(column_needle.get());
const auto * column_replacement_const = checkAndGetColumnConst<ColumnString>(column_replacement.get());
RUNTIME_CHECK(column_needle_string);
RUNTIME_CHECK(column_replacement_const);

GatherUtils::replace<Impl>(
ConstSource<StringSource>(*column_src_const),
StringSource(*column_needle_string),
ConstSource<StringSource>(*column_replacement_const),
res_col);
}

column_result.column = std::move(res_col);
}

void executeImplNonConstNeedle(
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
Int64 pos [[maybe_unused]],
Int64 occ [[maybe_unused]],
const String & match_type,
ColumnWithTypeAndName & column_result) const
{
if constexpr (Impl::support_non_const_needle)
Expand All @@ -195,13 +289,35 @@ class FunctionStringReplace : public IFunction
if (const auto * col = checkAndGetColumn<ColumnString>(column_src.get()))
{
auto col_res = ColumnString::create();
<<<<<<< HEAD
Impl::vectorNonConstNeedle(col->getChars(), col->getOffsets(), col_needle->getChars(), col_needle->getOffsets(), replacement, pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets());
=======
Impl::vectorNonConstNeedle(
col->getChars(),
col->getOffsets(),
col_needle->getChars(),
col_needle->getOffsets(),
replacement,
col_res->getChars(),
col_res->getOffsets());
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
column_result.column = std::move(col_res);
}
else if (const auto * col = checkAndGetColumn<ColumnFixedString>(column_src.get()))
{
auto col_res = ColumnString::create();
<<<<<<< HEAD
Impl::vectorFixedNonConstNeedle(col->getChars(), col->getN(), col_needle->getChars(), col_needle->getOffsets(), replacement, pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets());
=======
Impl::vectorFixedNonConstNeedle(
col->getChars(),
col->getN(),
col_needle->getChars(),
col_needle->getOffsets(),
replacement,
col_res->getChars(),
col_res->getOffsets());
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
column_result.column = std::move(col_res);
}
else
Expand All @@ -219,9 +335,6 @@ class FunctionStringReplace : public IFunction
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
Int64 pos [[maybe_unused]],
Int64 occ [[maybe_unused]],
const String & match_type,
ColumnWithTypeAndName & column_result) const
{
if constexpr (Impl::support_non_const_replacement)
Expand All @@ -233,13 +346,35 @@ class FunctionStringReplace : public IFunction
if (const auto * col = checkAndGetColumn<ColumnString>(column_src.get()))
{
auto col_res = ColumnString::create();
<<<<<<< HEAD
Impl::vectorNonConstReplacement(col->getChars(), col->getOffsets(), needle, col_replacement->getChars(), col_replacement->getOffsets(), pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets());
=======
Impl::vectorNonConstReplacement(
col->getChars(),
col->getOffsets(),
needle,
col_replacement->getChars(),
col_replacement->getOffsets(),
col_res->getChars(),
col_res->getOffsets());
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
column_result.column = std::move(col_res);
}
else if (const auto * col = checkAndGetColumn<ColumnFixedString>(column_src.get()))
{
auto col_res = ColumnString::create();
<<<<<<< HEAD
Impl::vectorFixedNonConstReplacement(col->getChars(), col->getN(), needle, col_replacement->getChars(), col_replacement->getOffsets(), pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets());
=======
Impl::vectorFixedNonConstReplacement(
col->getChars(),
col->getN(),
needle,
col_replacement->getChars(),
col_replacement->getOffsets(),
col_res->getChars(),
col_res->getOffsets());
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
column_result.column = std::move(col_res);
}
else
Expand All @@ -257,9 +392,6 @@ class FunctionStringReplace : public IFunction
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
Int64 pos [[maybe_unused]],
Int64 occ [[maybe_unused]],
const String & match_type,
ColumnWithTypeAndName & column_result) const
{
if constexpr (Impl::support_non_const_needle && Impl::support_non_const_replacement)
Expand All @@ -270,13 +402,37 @@ class FunctionStringReplace : public IFunction
if (const auto * col = checkAndGetColumn<ColumnString>(column_src.get()))
{
auto col_res = ColumnString::create();
<<<<<<< HEAD
Impl::vectorNonConstNeedleReplacement(col->getChars(), col->getOffsets(), col_needle->getChars(), col_needle->getOffsets(), col_replacement->getChars(), col_replacement->getOffsets(), pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets());
=======
Impl::vectorNonConstNeedleReplacement(
col->getChars(),
col->getOffsets(),
col_needle->getChars(),
col_needle->getOffsets(),
col_replacement->getChars(),
col_replacement->getOffsets(),
col_res->getChars(),
col_res->getOffsets());
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
column_result.column = std::move(col_res);
}
else if (const auto * col = checkAndGetColumn<ColumnFixedString>(column_src.get()))
{
auto col_res = ColumnString::create();
<<<<<<< HEAD
Impl::vectorFixedNonConstNeedleReplacement(col->getChars(), col->getN(), col_needle->getChars(), col_needle->getOffsets(), col_replacement->getChars(), col_replacement->getOffsets(), pos, occ, match_type, collator, col_res->getChars(), col_res->getOffsets());
=======
Impl::vectorFixedNonConstNeedleReplacement(
col->getChars(),
col->getN(),
col_needle->getChars(),
col_needle->getOffsets(),
col_replacement->getChars(),
col_replacement->getOffsets(),
col_res->getChars(),
col_res->getOffsets());
>>>>>>> 11ce13fffa (fix error when first argument of replace function is const (#9615))
column_result.column = std::move(col_res);
}
else
Expand All @@ -289,7 +445,5 @@ class FunctionStringReplace : public IFunction
throw Exception("Argument at index 2 and 3 for function replace must be constant", ErrorCodes::ILLEGAL_COLUMN);
}
}

TiDB::TiDBCollatorPtr collator{};
};
} // namespace DB
Loading

0 comments on commit ee676e1

Please sign in to comment.