Skip to content

Commit

Permalink
Refactor struct Converter and reduce template parameters (#7915)
Browse files Browse the repository at this point in the history
Summary:
Refactor struct Converter in Conversions.h to remove the code
complexity from the template parameter TRUNCATE and LEGACY_CAST.
Wrap them into a single template parameter TPolicy as well as
any future additional parameter.

struct Converter holds the core logic of the CAST expression
between scalar/primitive types. Benchmarking shows ignorable
impact on perf.

Also refactor the case of ToKind being BOOLEAN, and move the logic
into a single template specialization. The logic used to spread
over two template specializations, for TRUNCATE true and false.

Pull Request resolved: #7915

Reviewed By: Yuhta

Differential Revision: D51946081

Pulled By: gggrace14

fbshipit-source-id: f69624d07d58ca1946c806414f76b8e35597a271
  • Loading branch information
gggrace14 authored and facebook-github-bot committed Jan 22, 2024
1 parent 5db9d84 commit 063f2b6
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 122 deletions.
9 changes: 7 additions & 2 deletions velox/experimental/codegen/functions/CastFunctionStub.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@ struct CodegenConversionStub {

template <typename T>
static ReturnType<T> cast(const T& arg) {
using CastPolicy = typename std::conditional<
castByTruncate,
util::TTruncateCastPolicy,
util::TDefaultCastPolicy>::type;

if constexpr (std::is_same_v<codegen::TempString<TempsAllocator>, T>) {
return util::Converter<kind, void, castByTruncate>::cast(
return util::Converter<kind, void, CastPolicy>::cast(
folly::StringPiece(arg.data(), arg.size()));
} else {
return util::Converter<kind, void, castByTruncate>::cast(arg);
return util::Converter<kind, void, CastPolicy>::cast(arg);
}
}
};
Expand Down
16 changes: 8 additions & 8 deletions velox/expression/CastExpr-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,11 @@ void CastExpr::applyToSelectedNoThrowLocal(
/// The per-row level Kernel
/// @tparam ToKind The cast target type
/// @tparam FromKind The expression type
/// @tparam TPolicy The policy used by the cast
/// @param row The index of the current row
/// @param input The input vector (of type FromKind)
/// @param result The output vector (of type ToKind)
template <TypeKind ToKind, TypeKind FromKind, bool Truncate, bool LegacyCast>
template <TypeKind ToKind, TypeKind FromKind, typename TPolicy>
void CastExpr::applyCastKernel(
vector_size_t row,
EvalCtx& context,
Expand Down Expand Up @@ -222,8 +223,7 @@ void CastExpr::applyCastKernel(
}
}

auto output = util::Converter<ToKind, void, Truncate, LegacyCast>::cast(
inputRowValue);
auto output = util::Converter<ToKind, void, TPolicy>::cast(inputRowValue);

if constexpr (
ToKind == TypeKind::VARCHAR || ToKind == TypeKind::VARBINARY) {
Expand Down Expand Up @@ -317,7 +317,7 @@ VectorPtr CastExpr::applyDecimalToFloatCast(
const auto simpleInput = input.as<SimpleVector<FromNativeType>>();
const auto scaleFactor = DecimalUtil::kPowersOfTen[precisionScale.second];
applyToSelectedNoThrowLocal(context, rows, result, [&](int row) {
auto output = util::Converter<ToKind, void, false, false>::cast(
auto output = util::Converter<ToKind, void, util::DefaultCastPolicy>::cast(
simpleInput->valueAt(row));
resultBuffer[row] = output / scaleFactor;
});
Expand Down Expand Up @@ -507,24 +507,24 @@ void CastExpr::applyCastPrimitives(
if (!hooks_->truncate()) {
if (!hooks_->legacy()) {
applyToSelectedNoThrowLocal(context, rows, result, [&](int row) {
applyCastKernel<ToKind, FromKind, false /*truncate*/, false /*legacy*/>(
applyCastKernel<ToKind, FromKind, util::DefaultCastPolicy>(
row, context, inputSimpleVector, resultFlatVector);
});
} else {
applyToSelectedNoThrowLocal(context, rows, result, [&](int row) {
applyCastKernel<ToKind, FromKind, false /*truncate*/, true /*legacy*/>(
applyCastKernel<ToKind, FromKind, util::LegacyCastPolicy>(
row, context, inputSimpleVector, resultFlatVector);
});
}
} else {
if (!hooks_->legacy()) {
applyToSelectedNoThrowLocal(context, rows, result, [&](int row) {
applyCastKernel<ToKind, FromKind, true /*truncate*/, false /*legacy*/>(
applyCastKernel<ToKind, FromKind, util::TruncateCastPolicy>(
row, context, inputSimpleVector, resultFlatVector);
});
} else {
applyToSelectedNoThrowLocal(context, rows, result, [&](int row) {
applyCastKernel<ToKind, FromKind, true /*truncate*/, true /*legacy*/>(
applyCastKernel<ToKind, FromKind, util::TruncateLegacyCastPolicy>(
row, context, inputSimpleVector, resultFlatVector);
});
}
Expand Down
3 changes: 2 additions & 1 deletion velox/expression/CastExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,11 @@ class CastExpr : public SpecialForm {
/// The per-row level Kernel
/// @tparam ToKind The cast target type
/// @tparam FromKind The expression type
/// @tparam TPolicy The policy used by the cast
/// @param row The index of the current row
/// @param input The input vector (of type FromKind)
/// @param result The output vector (of type ToKind)
template <TypeKind ToKind, TypeKind FromKind, bool Truncate, bool LegacyCast>
template <TypeKind ToKind, TypeKind FromKind, typename TPolicy>
void applyCastKernel(
vector_size_t row,
EvalCtx& context,
Expand Down
8 changes: 4 additions & 4 deletions velox/expression/PrestoCastHooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ void PrestoCastHooks::castTimestampToString(
StringWriter<false>& out) const {
out.copy_from(
legacyCast_
? util::Converter<TypeKind::VARCHAR, void, false, true>::cast(
timestamp)
: util::Converter<TypeKind::VARCHAR, void, false, false>::cast(
timestamp));
? util::Converter<TypeKind::VARCHAR, void, util::LegacyCastPolicy>::
cast(timestamp)
: util::Converter<TypeKind::VARCHAR, void, util::DefaultCastPolicy>::
cast(timestamp));
out.finalize();
}

Expand Down
3 changes: 2 additions & 1 deletion velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ struct ArrayJoinFunction {
template <typename C>
void writeValue(out_type<velox::Varchar>& result, const C& value) {
result +=
util::Converter<TypeKind::VARCHAR, void, false, false>::cast(value);
util::Converter<TypeKind::VARCHAR, void, util::DefaultCastPolicy>::cast(
value);
}

template <typename C>
Expand Down
Loading

0 comments on commit 063f2b6

Please sign in to comment.