Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support elt pushdown #5496

Merged
merged 20 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ const std::unordered_map<tipb::ScalarFuncSig, String> scalar_func_map({
{tipb::ScalarFuncSig::Concat, "tidbConcat"},
{tipb::ScalarFuncSig::ConcatWS, "tidbConcatWS"},
//{tipb::ScalarFuncSig::Convert, "cast"},
//{tipb::ScalarFuncSig::Elt, "cast"},
{tipb::ScalarFuncSig::Elt, "elt"},
//{tipb::ScalarFuncSig::ExportSet3Arg, "cast"},
//{tipb::ScalarFuncSig::ExportSet4Arg, "cast"},
//{tipb::ScalarFuncSig::ExportSet5Arg, "cast"},
Expand Down
12 changes: 11 additions & 1 deletion dbms/src/Functions/FunctionHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,24 @@ bool checkColumn(const IColumn * column)
return checkAndGetColumn<Type>(column);
}

template <typename Type>
const Type * checkAndGetNestedColumn(const IColumn * column)
{
if (!column || !column->isColumnNullable())
return {};

const auto * data_column = &static_cast<const ColumnNullable *>(column)->getNestedColumn();

return checkAndGetColumn<Type>(data_column);
}

template <typename Type>
const ColumnConst * checkAndGetColumnConst(const IColumn * column, bool maybe_nullable_column = false)
{
if (!column || !column->isColumnConst())
return {};

const ColumnConst * res = static_cast<const ColumnConst *>(column);
const auto * res = static_cast<const ColumnConst *>(column);

const auto * data_column = &res->getDataColumn();
if (maybe_nullable_column && data_column->isColumnNullable())
Expand Down
172 changes: 168 additions & 4 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,11 @@
#include <Functions/GatherUtils/GatherUtils.h>
#include <Functions/StringUtil.h>
#include <Functions/castTypeToEither.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <fmt/core.h>
#include <fmt/format.h>
#include <fmt/printf.h>

#include <boost/algorithm/string/predicate.hpp>
#include <ext/range.h>
#include <thread>

namespace DB
{
Expand Down Expand Up @@ -5415,8 +5411,175 @@ class FunctionBin : public IFunction
throw Exception(fmt::format("Illegal argument of function {}", getName()), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}
};

class FunctionElt : public IFunction
{
public:
static constexpr auto name = "elt";

static FunctionPtr create(const Context & /*context*/)
{
return std::make_shared<FunctionElt>();
}

String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }

bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (arguments.size() < 2)
throw Exception(
fmt::format("Number of arguments for function {} doesn't match: passed {}, should be at least 2.", getName(), arguments.size()),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

auto first_argument = removeNullable(arguments[0]);
if (!first_argument->isInteger())
throw Exception(
fmt::format("Illegal type {} of first argument of function {}", first_argument->getName(), getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

for (const auto arg_idx : ext::range(1, arguments.size()))
{
const auto arg = removeNullable(arguments[arg_idx]);
if (!arg->isString())
throw Exception(
fmt::format("Illegal type {} of argument {} of function {}", arg->getName(), arg_idx + 1, getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}

return makeNullable(std::make_shared<DataTypeString>());
}

void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override
{
if (executeElt<UInt8>(block, arguments, result)
|| executeElt<UInt16>(block, arguments, result)
|| executeElt<UInt32>(block, arguments, result)
|| executeElt<UInt64>(block, arguments, result)
|| executeElt<Int8>(block, arguments, result)
|| executeElt<Int16>(block, arguments, result)
|| executeElt<Int32>(block, arguments, result)
|| executeElt<Int64>(block, arguments, result))
{
return;
}
else
{
throw Exception(fmt::format("Illegal argument of function {}", getName()), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}

private:
template <typename IntType>
static bool executeElt(Block & block, const ColumnNumbers & arguments, size_t result)
{
const auto * col_idx = block.getByPosition(arguments[0]).column.get();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable's name may be inappropriate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will use arg0 instead


if (const auto * col = checkAndGetColumnConst<ColumnVector<IntType>>(col_idx, true))
{
return constColumn<IntType>(col, block, arguments, result);
}
else
{
return vectorColumn<IntType>(col_idx, block, arguments, result);
}
}

static void fillResultColumnNull(Block & block, size_t result)
{
block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(block.rows(), Null());
}

template <typename IntType>
static bool constColumn(const ColumnConst * col, Block & block, const ColumnNumbers & arguments, size_t result)
{
const auto nrow = col->size();

if (col->onlyNull())
{
fillResultColumnNull(block, result);
return true;
}

const auto idx = col->getDataColumnPtr()->isColumnNullable()
? checkAndGetNestedColumn<ColumnVector<IntType>>(col->getDataColumnPtr().get())->getInt(0)
: col->getInt(0);

if (idx < 1 || idx >= static_cast<Int64>(arguments.size()))
{
fillResultColumnNull(block, result);
}
else
{
block.getByPosition(result).column = block.getByPosition(arguments[idx]).column->cloneResized(nrow);
}
return true;
}

template <typename IntType>
static bool vectorColumn(const IColumn * col, Block & block, const ColumnNumbers & arguments, size_t result)
{
const auto narg = arguments.size();
const auto nrow = col->size();
const auto col_idx = col->isColumnNullable()
? checkAndGetNestedColumn<ColumnVector<IntType>>(col)
: checkAndGetColumn<ColumnVector<IntType>>(col);

if (!col_idx)
{
return false;
}

const auto & idx_vec = col_idx->getData();

auto res_null_map = ColumnUInt8::create(nrow);
auto res_col = ColumnString::create();

for (size_t i = 0; i < nrow; ++i)
{
const auto idx = idx_vec[i];

if (col_idx->isNullAt(i) || idx < 1 || static_cast<Int64>(idx) >= static_cast<Int64>(narg))
{
res_null_map->getData()[i] = true;
res_col->insertDefault();
}
else
{
const auto arg_pos = arguments[idx];
const auto src_col = block.getByPosition(arg_pos).column.get();

if (src_col->isNullAt(i))
{
res_null_map->getData()[i] = true;
res_col->insertDefault();
}
else
{
res_null_map->getData()[i] = false;

const auto col_str = src_col->isColumnNullable()
? checkAndGetNestedColumn<ColumnString>(src_col)
: checkAndGetColumn<ColumnString>(src_col);
const auto & col_data = col_str->getChars();
const auto & col_offsets = col_str->getOffsets();

const auto start_offset = StringUtil::offsetAt(col_offsets, i);
const auto str_size = StringUtil::sizeAt(col_offsets, i);

res_col->insertDataWithTerminatingZero(reinterpret_cast<const char *>(&col_data[start_offset]), str_size);
}
}
}

block.getByPosition(result).column = ColumnNullable::create(std::move(res_col), std::move(res_null_map));
return true;
}
};

// clang-format off
Expand Down Expand Up @@ -5507,5 +5670,6 @@ void registerFunctionsString(FunctionFactory & factory)
factory.registerFunction<FunctionHexInt>();
factory.registerFunction<FunctionRepeat>();
factory.registerFunction<FunctionBin>();
factory.registerFunction<FunctionElt>();
}
} // namespace DB
Loading