Skip to content

Commit

Permalink
ARROW-13289: [C++] Accept integer args in trig/log functions via prom…
Browse files Browse the repository at this point in the history
…otion to double

Instead of adding/generating separate kernels for integers, just promote the arguments instead.

Closes apache#10686 from lidavidm/arrow-13289

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
lidavidm committed Jul 12, 2021
1 parent ba009fb commit 090e2cf
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 4 deletions.
40 changes: 37 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,37 @@ struct ArithmeticFunction : ScalarFunction {
}
};

/// An ArithmeticFunction that promotes integer arguments to double.
struct ArithmeticFloatingPointFunction : public ArithmeticFunction {
using ArithmeticFunction::ArithmeticFunction;

Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
RETURN_NOT_OK(CheckArity(*values));
RETURN_NOT_OK(CheckDecimals(values));

using arrow::compute::detail::DispatchExactImpl;
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;

EnsureDictionaryDecoded(values);

if (values->size() == 2) {
ReplaceNullWithOtherType(values);
}

for (auto& descr : *values) {
if (is_integer(descr.type->id())) {
descr.type = float64();
}
}
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
}

if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *values);
}
};

template <typename Op>
std::shared_ptr<ScalarFunction> MakeArithmeticFunction(std::string name,
const FunctionDoc* doc) {
Expand Down Expand Up @@ -1164,7 +1195,8 @@ std::shared_ptr<ScalarFunction> MakeShiftFunctionNotNull(std::string name,
template <typename Op>
std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPoint(
std::string name, const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
auto func =
std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Unary(), doc);
for (const auto& ty : FloatingPointTypes()) {
auto output = is_integer(ty->id()) ? float64() : ty;
auto exec = GenerateArithmeticFloatingPoint<ScalarUnary, Op>(ty);
Expand All @@ -1176,7 +1208,8 @@ std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPoint(
template <typename Op>
std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPointNotNull(
std::string name, const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
auto func =
std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Unary(), doc);
for (const auto& ty : FloatingPointTypes()) {
auto output = is_integer(ty->id()) ? float64() : ty;
auto exec = GenerateArithmeticFloatingPoint<ScalarUnaryNotNull, Op>(ty);
Expand All @@ -1188,7 +1221,8 @@ std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPointNotNull(
template <typename Op>
std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPoint(
std::string name, const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
auto func =
std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Binary(), doc);
for (const auto& ty : FloatingPointTypes()) {
auto output = is_integer(ty->id()) ? float64() : ty;
auto exec = GenerateArithmeticFloatingPoint<ScalarBinaryEqualTypes, Op>(ty);
Expand Down
94 changes: 93 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,26 @@ TEST(TestUnaryArithmetic, DispatchBest) {
for (std::string name : {"negate", "negate_checked", "abs", "abs_checked"}) {
CheckDispatchFails(name, {null()});
}

for (std::string name :
{"ln", "log2", "log10", "log1p", "sin", "cos", "tan", "asin", "acos"}) {
for (std::string suffix : {"", "_checked"}) {
name += suffix;

CheckDispatchBest(name, {int32()}, {float64()});
CheckDispatchBest(name, {uint8()}, {float64()});

CheckDispatchBest(name, {dictionary(int8(), int64())}, {float64()});
}
}

CheckDispatchBest("atan", {int32()}, {float64()});
CheckDispatchBest("atan2", {int32(), float64()}, {float64(), float64()});
CheckDispatchBest("atan2", {int32(), uint8()}, {float64(), float64()});
CheckDispatchBest("atan2", {int32(), null()}, {float64(), float64()});
CheckDispatchBest("atan2", {float32(), float64()}, {float64(), float64()});
// Integer always promotes to double
CheckDispatchBest("atan2", {float32(), int8()}, {float64(), float64()});
}

TYPED_TEST(TestUnaryArithmeticSigned, Negate) {
Expand Down Expand Up @@ -1821,9 +1841,41 @@ TYPED_TEST(TestBinaryArithmeticFloating, TrigAtan2) {
-M_PI_2, 0, M_PI));
}

TYPED_TEST(TestUnaryArithmeticIntegral, Trig) {
// Integer arguments promoted to double, sanity check here
auto ty = this->type_singleton();
auto atan = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
return Atan(arg, ctx);
};
for (auto check_overflow : {false, true}) {
this->SetOverflowCheck(check_overflow);
this->AssertUnaryOp(Sin, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), "[0, 0.8414709848078965]"));
this->AssertUnaryOp(Cos, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), "[1, 0.5403023058681398]"));
this->AssertUnaryOp(Tan, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), "[0, 1.5574077246549023]"));
this->AssertUnaryOp(Asin, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), MakeArray(0, M_PI_2)));
this->AssertUnaryOp(Acos, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), MakeArray(M_PI_2, 0)));
this->AssertUnaryOp(atan, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), MakeArray(0, M_PI_4)));
}
}

TYPED_TEST(TestBinaryArithmeticIntegral, Trig) {
// Integer arguments promoted to double, sanity check here
auto ty = this->type_singleton();
auto atan2 = [](const Datum& y, const Datum& x, ArithmeticOptions, ExecContext* ctx) {
return Atan2(y, x, ctx);
};
this->AssertBinop(atan2, ArrayFromJSON(ty, "[0, 1]"), ArrayFromJSON(ty, "[1, 0]"),
ArrayFromJSON(float64(), MakeArray(0, M_PI_2)));
}

TYPED_TEST(TestUnaryArithmeticFloating, Log) {
using CType = typename TestFixture::CType;
auto ty = this->type_singleton();
this->SetNansEqual(true);
auto min_val = std::numeric_limits<CType>::min();
auto max_val = std::numeric_limits<CType>::max();
Expand Down Expand Up @@ -1881,5 +1933,45 @@ TYPED_TEST(TestUnaryArithmeticFloating, Log) {
Log1p(lowest_val, this->options_));
}

TYPED_TEST(TestUnaryArithmeticIntegral, Log) {
// Integer arguments promoted to double, sanity check here
auto ty = this->type_singleton();
for (auto check_overflow : {false, true}) {
this->SetOverflowCheck(check_overflow);
this->AssertUnaryOp(Ln, ArrayFromJSON(ty, "[1, null]"),
ArrayFromJSON(float64(), "[0, null]"));
this->AssertUnaryOp(Log10, ArrayFromJSON(ty, "[1, 10, null]"),
ArrayFromJSON(float64(), "[0, 1, null]"));
this->AssertUnaryOp(Log2, ArrayFromJSON(ty, "[1, 2, null]"),
ArrayFromJSON(float64(), "[0, 1, null]"));
this->AssertUnaryOp(Log1p, ArrayFromJSON(ty, "[0, null]"),
ArrayFromJSON(float64(), "[0, null]"));
}
}

TYPED_TEST(TestUnaryArithmeticSigned, Log) {
// Integer arguments promoted to double, sanity check here
auto ty = this->type_singleton();
this->SetNansEqual(true);
this->SetOverflowCheck(false);
this->AssertUnaryOp(Ln, ArrayFromJSON(ty, "[-1, 0]"),
ArrayFromJSON(float64(), "[NaN, -Inf]"));
this->AssertUnaryOp(Log10, ArrayFromJSON(ty, "[-1, 0]"),
ArrayFromJSON(float64(), "[NaN, -Inf]"));
this->AssertUnaryOp(Log2, ArrayFromJSON(ty, "[-1, 0]"),
ArrayFromJSON(float64(), "[NaN, -Inf]"));
this->AssertUnaryOp(Log1p, ArrayFromJSON(ty, "[-2, -1]"),
ArrayFromJSON(float64(), "[NaN, -Inf]"));
this->SetOverflowCheck(true);
this->AssertUnaryOpRaises(Ln, "[0]", "logarithm of zero");
this->AssertUnaryOpRaises(Ln, "[-1]", "logarithm of negative number");
this->AssertUnaryOpRaises(Log10, "[0]", "logarithm of zero");
this->AssertUnaryOpRaises(Log10, "[-1]", "logarithm of negative number");
this->AssertUnaryOpRaises(Log2, "[0]", "logarithm of zero");
this->AssertUnaryOpRaises(Log2, "[-1]", "logarithm of negative number");
this->AssertUnaryOpRaises(Log1p, "[-1]", "logarithm of zero");
this->AssertUnaryOpRaises(Log1p, "[-2]", "logarithm of negative number");
}

} // namespace compute
} // namespace arrow

0 comments on commit 090e2cf

Please sign in to comment.