Skip to content

Commit

Permalink
GH-43956: [C++][Compute] Add Decimal32/64 Casts (#45014)
Browse files Browse the repository at this point in the history
<!--
Thanks for opening a pull request!
If this is your first pull request you can find detailed information on
how
to contribute here:
* [New Contributor's
Guide](https://arrow.apache.org/docs/dev/developers/guide/step_by_step/pr_lifecycle.html#reviews-and-merge-of-the-pull-request)
* [Contributing
Overview](https://arrow.apache.org/docs/dev/developers/overview.html)


If this is not a [minor
PR](https://github.com/apache/arrow/blob/main/CONTRIBUTING.md#Minor-Fixes).
Could you open an issue for this pull request on GitHub?
https://github.com/apache/arrow/issues/new/choose

Opening GitHub issues ahead of time contributes to the
[Openness](http://theapacheway.com/open/#:~:text=Openness%20allows%20new%20users%20the,must%20happen%20in%20the%20open.)
of the Apache Arrow project.

Then could you also rename the pull request title in the following
format?

    GH-${GITHUB_ISSUE_ID}: [${COMPONENT}] ${SUMMARY}

or

    MINOR: [${COMPONENT}] ${SUMMARY}

-->

### Rationale for this change
Furthering the support for Decimal32/Decimal64 among Acero and casting
functionality. This is also necessary for #44882 to add support for
Decimal32/64 to PyArrow

<!--
Why are you proposing this change? If this is already explained clearly
in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->

### What changes are included in this PR?
Adding kernels for casting to and from Decimal32/Decimal64 between
numeric, floating point, string and other decimal types.

<!--
There is no need to duplicate the description in the issue here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->

### Are these changes tested?
Yes, unit tests are added accordingly.

<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->

* GitHub Issue: #43956
  • Loading branch information
zeroshade authored Dec 13, 2024
1 parent f9a6eda commit 313d11a
Show file tree
Hide file tree
Showing 4 changed files with 1,391 additions and 31 deletions.
177 changes: 175 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,43 @@ struct DecimalConversions<Decimal256, InDecimal> {
static Decimal256 ConvertOutput(Decimal256&& val) { return val; }
};

template <typename InDecimal>
struct DecimalConversions<Decimal32, InDecimal> {
static Decimal32 ConvertInput(InDecimal&& val) { return Decimal32(val.low_bits()); }
static Decimal32 ConvertOutput(Decimal32&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal64, Decimal32> {
// Convert then scale
static Decimal64 ConvertInput(Decimal32&& val) { return Decimal64(val); }
static Decimal64 ConvertOutput(Decimal64&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal64, Decimal64> {
static Decimal64 ConvertInput(Decimal64&& val) { return val; }
static Decimal64 ConvertOutput(Decimal64&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal64, Decimal128> {
// Scale then truncate
static Decimal128 ConvertInput(Decimal128&& val) { return val; }
static Decimal64 ConvertOutput(Decimal128&& val) {
return Decimal64(static_cast<int64_t>(val.low_bits()));
}
};

template <>
struct DecimalConversions<Decimal64, Decimal256> {
// Scale then truncate
static Decimal256 ConvertInput(Decimal256&& val) { return val; }
static Decimal64 ConvertOutput(Decimal256&& val) {
return Decimal64(static_cast<int64_t>(val.low_bits()));
}
};

template <>
struct DecimalConversions<Decimal128, Decimal256> {
// Scale then truncate
Expand All @@ -495,6 +532,20 @@ struct DecimalConversions<Decimal128, Decimal128> {
static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal128, Decimal64> {
// convert then scale
static Decimal128 ConvertInput(Decimal64&& val) { return Decimal128(val.value()); }
static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal128, Decimal32> {
// convert then scale
static Decimal128 ConvertInput(Decimal32&& val) { return Decimal128(val.value()); }
static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
};

struct UnsafeUpscaleDecimal {
template <typename OutValue, typename Arg0Value>
OutValue Call(KernelContext*, Arg0Value val, Status*) const {
Expand Down Expand Up @@ -659,6 +710,18 @@ struct DecimalCastFunctor {
}
};

template <typename I>
struct CastFunctor<
Decimal32Type, I,
enable_if_t<is_base_binary_type<I>::value || is_binary_view_like_type<I>::value>>
: public DecimalCastFunctor<Decimal32Type, I> {};

template <typename I>
struct CastFunctor<
Decimal64Type, I,
enable_if_t<is_base_binary_type<I>::value || is_binary_view_like_type<I>::value>>
: public DecimalCastFunctor<Decimal64Type, I> {};

template <typename I>
struct CastFunctor<
Decimal128Type, I,
Expand Down Expand Up @@ -744,6 +807,10 @@ std::shared_ptr<CastFunction> GetCastToInteger(std::string name) {
// From decimal to integer
DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
CastFunctor<OutType, Decimal128Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, out_ty,
CastFunctor<OutType, Decimal32Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, out_ty,
CastFunctor<OutType, Decimal64Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
CastFunctor<OutType, Decimal256Type>::Exec));
return func;
Expand Down Expand Up @@ -772,6 +839,10 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
AddCommonNumberCasts<OutType>(out_ty, func.get());

// From decimal to floating point
DCHECK_OK(func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, out_ty,
CastFunctor<OutType, Decimal32Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, out_ty,
CastFunctor<OutType, Decimal64Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
CastFunctor<OutType, Decimal128Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
Expand All @@ -780,6 +851,94 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
return func;
}

std::shared_ptr<CastFunction> GetCastToDecimal32() {
OutputType sig_out_ty(ResolveOutputFromOptions);

auto func = std::make_shared<CastFunction>("cast_decimal32", Type::DECIMAL32);
AddCommonCasts(Type::DECIMAL32, sig_out_ty, func.get());

// Cast from floating point
DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
CastFunctor<Decimal32Type, FloatType>::Exec));
DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
CastFunctor<Decimal32Type, DoubleType>::Exec));

// Cast from integer
for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
auto exec = GenerateInteger<CastFunctor, Decimal32Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}

// Cast from other strings
for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
auto exec = GenerateVarBinaryBase<CastFunctor, Decimal32Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}
for (const std::shared_ptr<DataType>& in_ty : BinaryViewTypes()) {
auto exec = GenerateVarBinaryViewBase<CastFunctor, Decimal32Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}

// Cast from other decimal
auto exec = CastFunctor<Decimal32Type, Decimal32Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, sig_out_ty, exec));
exec = CastFunctor<Decimal32Type, Decimal64Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, sig_out_ty, exec));
exec = CastFunctor<Decimal32Type, Decimal128Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
exec = CastFunctor<Decimal32Type, Decimal256Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
return func;
}

std::shared_ptr<CastFunction> GetCastToDecimal64() {
OutputType sig_out_ty(ResolveOutputFromOptions);

auto func = std::make_shared<CastFunction>("cast_decimal64", Type::DECIMAL64);
AddCommonCasts(Type::DECIMAL64, sig_out_ty, func.get());

// Cast from floating point
DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
CastFunctor<Decimal64Type, FloatType>::Exec));
DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
CastFunctor<Decimal64Type, DoubleType>::Exec));

// Cast from integer
for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
auto exec = GenerateInteger<CastFunctor, Decimal64Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}

// Cast from other strings
for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
auto exec = GenerateVarBinaryBase<CastFunctor, Decimal64Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}
for (const std::shared_ptr<DataType>& in_ty : BinaryViewTypes()) {
auto exec = GenerateVarBinaryViewBase<CastFunctor, Decimal64Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}

// Cast from other decimal
auto exec = CastFunctor<Decimal64Type, Decimal32Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, sig_out_ty, exec));
exec = CastFunctor<Decimal64Type, Decimal64Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, sig_out_ty, exec));
exec = CastFunctor<Decimal64Type, Decimal128Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
exec = CastFunctor<Decimal64Type, Decimal256Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
return func;
}

std::shared_ptr<CastFunction> GetCastToDecimal128() {
OutputType sig_out_ty(ResolveOutputFromOptions);

Expand Down Expand Up @@ -809,8 +968,14 @@ std::shared_ptr<CastFunction> GetCastToDecimal128() {
}

// Cast from other decimal
auto exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec;
auto exec = CastFunctor<Decimal128Type, Decimal32Type>::Exec;
// We resolve the output type of this kernel from the CastOptions
DCHECK_OK(
func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, sig_out_ty, exec));
exec = CastFunctor<Decimal128Type, Decimal64Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, sig_out_ty, exec));
exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
exec = CastFunctor<Decimal128Type, Decimal256Type>::Exec;
Expand Down Expand Up @@ -848,7 +1013,13 @@ std::shared_ptr<CastFunction> GetCastToDecimal256() {
}

// Cast from other decimal
auto exec = CastFunctor<Decimal256Type, Decimal128Type>::Exec;
auto exec = CastFunctor<Decimal256Type, Decimal32Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, sig_out_ty, exec));
exec = CastFunctor<Decimal256Type, Decimal64Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, sig_out_ty, exec));
exec = CastFunctor<Decimal256Type, Decimal128Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
exec = CastFunctor<Decimal256Type, Decimal256Type>::Exec;
Expand Down Expand Up @@ -950,6 +1121,8 @@ std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
auto cast_double = GetCastToFloating<DoubleType>("cast_double");
functions.push_back(cast_double);

functions.push_back(GetCastToDecimal32());
functions.push_back(GetCastToDecimal64());
functions.push_back(GetCastToDecimal128());
functions.push_back(GetCastToDecimal256());

Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_cast_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,8 @@ void AddNumberToStringCasts(CastFunction* func) {
template <typename OutType>
void AddDecimalToStringCasts(CastFunction* func) {
auto out_ty = TypeTraits<OutType>::type_singleton();
for (const auto& in_tid : std::vector<Type::type>{Type::DECIMAL128, Type::DECIMAL256}) {
for (const auto& in_tid : std::vector<Type::type>{Type::DECIMAL32, Type::DECIMAL64,
Type::DECIMAL128, Type::DECIMAL256}) {
DCHECK_OK(
func->AddKernel(in_tid, {in_tid}, out_ty,
GenerateDecimal<DecimalToStringCastFunctor, OutType>(in_tid),
Expand Down
Loading

0 comments on commit 313d11a

Please sign in to comment.