diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc index 6613c1f12c6af..f34289f372ebb 100644 --- a/cpp/src/gandiva/function_registry_arithmetic.cc +++ b/cpp/src/gandiva/function_registry_arithmetic.cc @@ -29,9 +29,13 @@ namespace gandiva { #define BINARY_RELATIONAL_BOOL_DATE_FN(name, ALIASES) \ NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES) -#define UNARY_CAST_TO_FLOAT64(name) UNARY_SAFE_NULL_IF_NULL(castFLOAT8, {}, name, float64) +#define UNARY_CAST_TO_FLOAT64(type) UNARY_SAFE_NULL_IF_NULL(castFLOAT8, {}, type, float64) -#define UNARY_CAST_TO_FLOAT32(name) UNARY_SAFE_NULL_IF_NULL(castFLOAT4, {}, name, float32) +#define UNARY_CAST_TO_FLOAT32(type) UNARY_SAFE_NULL_IF_NULL(castFLOAT4, {}, type, float32) + +#define UNARY_CAST_TO_INT32(type) UNARY_SAFE_NULL_IF_NULL(castINT, {}, type, int32) + +#define UNARY_CAST_TO_INT64(type) UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, type, int64) std::vector GetArithmeticFunctionRegistry() { static std::vector arithmetic_fn_registry_ = { @@ -44,6 +48,12 @@ std::vector GetArithmeticFunctionRegistry() { UNARY_CAST_TO_FLOAT32(int32), UNARY_CAST_TO_FLOAT32(int64), UNARY_CAST_TO_FLOAT32(float64), + // cast to int32 + UNARY_CAST_TO_INT32(float32), UNARY_CAST_TO_INT32(float64), + + // cast to int64 + UNARY_CAST_TO_INT64(float32), UNARY_CAST_TO_INT64(float64), + // cast to float64 UNARY_CAST_TO_FLOAT64(int32), UNARY_CAST_TO_FLOAT64(int64), UNARY_CAST_TO_FLOAT64(float32), UNARY_CAST_TO_FLOAT64(decimal128), diff --git a/cpp/src/gandiva/function_registry_common.h b/cpp/src/gandiva/function_registry_common.h index 40efc1fe1a978..66f945150897a 100644 --- a/cpp/src/gandiva/function_registry_common.h +++ b/cpp/src/gandiva/function_registry_common.h @@ -43,6 +43,7 @@ using arrow::int16; using arrow::int32; using arrow::int64; using arrow::int8; +using arrow::month_interval; using arrow::uint16; using arrow::uint32; using arrow::uint64; diff --git a/cpp/src/gandiva/function_registry_datetime.cc b/cpp/src/gandiva/function_registry_datetime.cc index 6e7a703aa61b8..b8d2e7b6c7dfe 100644 --- a/cpp/src/gandiva/function_registry_datetime.cc +++ b/cpp/src/gandiva/function_registry_datetime.cc @@ -93,6 +93,32 @@ std::vector GetDateTimeFunctionRegistry() { NativeFunction("castBIGINT", {}, DataTypeVector{day_time_interval()}, int64(), kResultNullIfNull, "castBIGINT_daytimeinterval"), + NativeFunction("castINT", {"castNULLABLEINT"}, DataTypeVector{month_interval()}, + int32(), kResultNullIfNull, "castINT_year_interval", + NativeFunction::kCanReturnErrors), + + NativeFunction("castBIGINT", {"castNULLABLEBIGINT"}, + DataTypeVector{month_interval()}, int64(), kResultNullIfNull, + "castBIGINT_year_interval", NativeFunction::kCanReturnErrors), + + NativeFunction("castNULLABLEINTERVALYEAR", {"castINTERVALYEAR"}, + DataTypeVector{int32()}, month_interval(), kResultNullIfNull, + "castNULLABLEINTERVALYEAR_int32", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castNULLABLEINTERVALYEAR", {"castINTERVALYEAR"}, + DataTypeVector{int64()}, month_interval(), kResultNullIfNull, + "castNULLABLEINTERVALYEAR_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castNULLABLEINTERVALDAY", {"castINTERVALDAY"}, + DataTypeVector{int32()}, day_time_interval(), kResultNullIfNull, + "castNULLABLEINTERVALDAY_int32"), + + NativeFunction("castNULLABLEINTERVALDAY", {"castINTERVALDAY"}, + DataTypeVector{int64()}, day_time_interval(), kResultNullIfNull, + "castNULLABLEINTERVALDAY_int64"), + NativeFunction("extractDay", {}, DataTypeVector{day_time_interval()}, int64(), kResultNullIfNull, "extractDay_daytimeinterval"), diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 7736b9750895e..670ac94df1b89 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -42,6 +42,7 @@ using gdv_timestamp = int64_t; using gdv_utf8 = char*; using gdv_binary = char*; using gdv_day_time_interval = int64_t; +using gdv_month_interval = int32_t; #ifdef GANDIVA_UNIT_TEST // unit tests may be compiled without O2, so inlining may not happen. diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops.cc b/cpp/src/gandiva/precompiled/arithmetic_ops.cc index a173a60d6d000..c736c38d32c70 100644 --- a/cpp/src/gandiva/precompiled/arithmetic_ops.cc +++ b/cpp/src/gandiva/precompiled/arithmetic_ops.cc @@ -122,6 +122,21 @@ CAST_UNARY(castFLOAT4, float64, float32) #undef CAST_UNARY +// cast float types to int types. +#define CAST_INT_FLOAT(NAME, IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE NAME##_##IN_TYPE(gdv_##IN_TYPE in) { \ + gdv_##OUT_TYPE out = static_cast(round(in)); \ + return out; \ + } + +CAST_INT_FLOAT(castBIGINT, float32, int64) +CAST_INT_FLOAT(castBIGINT, float64, int64) +CAST_INT_FLOAT(castINT, float32, int32) +CAST_INT_FLOAT(castINT, float64, int32) + +#undef CAST_INT_FLOAT + // simple nullable functions, result value = fn(input validity) #define VALIDITY_OP(NAME, TYPE, OP) \ FORCE_INLINE \ diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc index b3359ac7d6c70..36b50bcfdae8d 100644 --- a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc +++ b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc @@ -137,4 +137,44 @@ TEST(TestArithmeticOps, TestBitwiseOps) { EXPECT_EQ(bitwise_not_int64(0x0000000000000000), 0xFFFFFFFFFFFFFFFF); } +TEST(TestArithmeticOps, TestIntCastFloatDouble) { + // castINT from floats + EXPECT_EQ(castINT_float32(6.6f), 7); + EXPECT_EQ(castINT_float32(-6.6f), -7); + EXPECT_EQ(castINT_float32(-6.3f), -6); + EXPECT_EQ(castINT_float32(0.0f), 0); + EXPECT_EQ(castINT_float32(-0), 0); + + // castINT from doubles + EXPECT_EQ(castINT_float64(6.6), 7); + EXPECT_EQ(castINT_float64(-6.6), -7); + EXPECT_EQ(castINT_float64(-6.3), -6); + EXPECT_EQ(castINT_float64(0.0), 0); + EXPECT_EQ(castINT_float64(-0), 0); + EXPECT_EQ(castINT_float64(999999.99999999999999999999999), 1000000); + EXPECT_EQ(castINT_float64(-999999.99999999999999999999999), -1000000); + EXPECT_EQ(castINT_float64(INT32_MAX), 2147483647); + EXPECT_EQ(castINT_float64(-2147483647), -2147483647); +} + +TEST(TestArithmeticOps, TestBigIntCastFloatDouble) { + // castINT from floats + EXPECT_EQ(castBIGINT_float32(6.6f), 7); + EXPECT_EQ(castBIGINT_float32(-6.6f), -7); + EXPECT_EQ(castBIGINT_float32(-6.3f), -6); + EXPECT_EQ(castBIGINT_float32(0.0f), 0); + EXPECT_EQ(castBIGINT_float32(-0), 0); + + // castINT from doubles + EXPECT_EQ(castBIGINT_float64(6.6), 7); + EXPECT_EQ(castBIGINT_float64(-6.6), -7); + EXPECT_EQ(castBIGINT_float64(-6.3), -6); + EXPECT_EQ(castBIGINT_float64(0.0), 0); + EXPECT_EQ(castBIGINT_float64(-0), 0); + EXPECT_EQ(castBIGINT_float64(999999.99999999999999999999999), 1000000); + EXPECT_EQ(castBIGINT_float64(-999999.99999999999999999999999), -1000000); + EXPECT_EQ(castBIGINT_float64(INT32_MAX), 2147483647); + EXPECT_EQ(castBIGINT_float64(-2147483647), -2147483647); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/time.cc b/cpp/src/gandiva/precompiled/time.cc index e5cdd9de64f80..336f692267dd1 100644 --- a/cpp/src/gandiva/precompiled/time.cc +++ b/cpp/src/gandiva/precompiled/time.cc @@ -860,4 +860,35 @@ NUMERIC_TYPES(TO_TIMESTAMP) NUMERIC_TYPES(TO_TIME) +#define CAST_INT_YEAR_INTERVAL(TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE TYPE##_year_interval(gdv_month_interval in) { \ + return static_cast(in / 12.0); \ + } + +CAST_INT_YEAR_INTERVAL(castBIGINT, int64) +CAST_INT_YEAR_INTERVAL(castINT, int32) + +#define CAST_NULLABLE_INTERVAL_DAY(TYPE) \ + FORCE_INLINE \ + gdv_day_time_interval castNULLABLEINTERVALDAY_##TYPE(gdv_##TYPE in) { \ + return static_cast(in); \ + } + +CAST_NULLABLE_INTERVAL_DAY(int32) +CAST_NULLABLE_INTERVAL_DAY(int64) + +#define CAST_NULLABLE_INTERVAL_YEAR(TYPE) \ + FORCE_INLINE \ + gdv_month_interval castNULLABLEINTERVALYEAR_##TYPE(int64_t context, gdv_##TYPE in) { \ + gdv_month_interval value = static_cast(in); \ + if (value != in) { \ + gdv_fn_context_set_error_msg(context, "Integer overflow"); \ + } \ + return value; \ + } + +CAST_NULLABLE_INTERVAL_YEAR(int32) +CAST_NULLABLE_INTERVAL_YEAR(int64) + } // extern "C" diff --git a/cpp/src/gandiva/precompiled/time_test.cc b/cpp/src/gandiva/precompiled/time_test.cc index 8d3cdccd6ff1e..cec3cf747c288 100644 --- a/cpp/src/gandiva/precompiled/time_test.cc +++ b/cpp/src/gandiva/precompiled/time_test.cc @@ -869,4 +869,48 @@ TEST(TestTime, TestToTimeNumeric) { EXPECT_EQ(expected_output, to_time_float64(3601.500)); } +TEST(TestTime, TestCastIntDayInterval) { + EXPECT_EQ(castBIGINT_daytimeinterval(10), 864000000); + EXPECT_EQ(castBIGINT_daytimeinterval(-100), -8640000001); + EXPECT_EQ(castBIGINT_daytimeinterval(-0), 0); +} + +TEST(TestTime, TestCastIntYearInterval) { + EXPECT_EQ(castINT_year_interval(24), 2); + EXPECT_EQ(castINT_year_interval(-24), -2); + EXPECT_EQ(castINT_year_interval(-23), -1); + + EXPECT_EQ(castBIGINT_year_interval(24), 2); + EXPECT_EQ(castBIGINT_year_interval(-24), -2); + EXPECT_EQ(castBIGINT_year_interval(-23), -1); +} + +TEST(TestTime, TestCastNullableInterval) { + ExecutionContext context; + auto context_ptr = reinterpret_cast(&context); + // Test castNULLABLEINTERVALDAY for int and bigint + EXPECT_EQ(castNULLABLEINTERVALDAY_int32(1), 1); + EXPECT_EQ(castNULLABLEINTERVALDAY_int32(12), 12); + EXPECT_EQ(castNULLABLEINTERVALDAY_int32(-55), -55); + EXPECT_EQ(castNULLABLEINTERVALDAY_int32(-1201), -1201); + EXPECT_EQ(castNULLABLEINTERVALDAY_int64(1), 1); + EXPECT_EQ(castNULLABLEINTERVALDAY_int64(12), 12); + EXPECT_EQ(castNULLABLEINTERVALDAY_int64(-55), -55); + EXPECT_EQ(castNULLABLEINTERVALDAY_int64(-1201), -1201); + + // Test castNULLABLEINTERVALYEAR for int and bigint + EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 1), 1); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 12), 12); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 55), 55); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 1201), 1201); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 1), 1); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 12), 12); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 55), 55); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 1201), 1201); + // validate overflow error when using bigint as input + castNULLABLEINTERVALYEAR_int64(context_ptr, INT64_MAX); + EXPECT_EQ(context.get_error(), "Integer overflow"); + context.Reset(); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index 4e913aaac67a8..7032f45997432 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -543,4 +543,26 @@ float castFLOAT4_utf8(int64_t context, const char* data, int32_t len); double castFLOAT8_utf8(int64_t context, const char* data, int32_t len); +int32_t castINT_float32(gdv_float32 value); + +int32_t castINT_float64(gdv_float64 value); + +int64_t castBIGINT_float32(gdv_float32 value); + +int64_t castBIGINT_float64(gdv_float64 value); + +int64_t castBIGINT_daytimeinterval(gdv_day_time_interval in); + +int32_t castINT_year_interval(gdv_month_interval in); + +int64_t castBIGINT_year_interval(gdv_month_interval in); + +gdv_day_time_interval castNULLABLEINTERVALDAY_int32(gdv_int32 in); + +gdv_day_time_interval castNULLABLEINTERVALDAY_int64(gdv_int64 in); + +gdv_month_interval castNULLABLEINTERVALYEAR_int32(int64_t context, gdv_int32 in); + +gdv_month_interval castNULLABLEINTERVALYEAR_int64(int64_t context, gdv_int64 in); + } // extern "C" diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 33f83a44c4c6c..2ce52befdc149 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -1409,4 +1409,163 @@ TEST_F(TestProjector, TestBinRepresentation) { EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); } +TEST_F(TestProjector, TestBigIntCastFunction) { + // input fields + auto field0 = field("f0", arrow::float32()); + auto field1 = field("f1", arrow::float64()); + auto field2 = field("f2", arrow::day_time_interval()); + auto field3 = field("f3", arrow::month_interval()); + auto schema = arrow::schema({field0, field1, field2, field3}); + + // output fields + auto res_int64 = field("res", arrow::int64()); + + // Build expression + auto cast_expr_float4 = + TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int64); + auto cast_expr_float8 = + TreeExprBuilder::MakeExpression("castBIGINT", {field1}, res_int64); + auto cast_expr_day_interval = + TreeExprBuilder::MakeExpression("castBIGINT", {field2}, res_int64); + auto cast_expr_year_interval = + TreeExprBuilder::MakeExpression("castBIGINT", {field3}, res_int64); + + std::shared_ptr projector; + + // {cast_expr_float4, cast_expr_float8, cast_expr_day_interval, + // cast_expr_year_interval} + auto status = Projector::Make(schema, + {cast_expr_float4, cast_expr_float8, + cast_expr_day_interval, cast_expr_year_interval}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = + MakeArrowArrayFloat32({6.6f, -6.6f, 9.999999f, 0}, {true, true, true, false}); + auto array1 = + MakeArrowArrayFloat64({6.6, -6.6, 9.99999999999, 0}, {true, true, true, false}); + auto array2 = MakeArrowArrayInt64({100, 25, -0, 0}, {true, true, true, false}); + auto array3 = MakeArrowArrayInt32({25, -25, -0, 0}, {true, true, true, false}); + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2, array3}); + + auto out_float4 = MakeArrowArrayInt64({7, -7, 10, 0}, {true, true, true, false}); + auto out_float8 = MakeArrowArrayInt64({7, -7, 10, 0}, {true, true, true, false}); + auto out_days_interval = + MakeArrowArrayInt64({8640000000, 2160000000, 0, 0}, {true, true, true, false}); + auto out_year_interval = MakeArrowArrayInt64({2, -2, 0, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(out_days_interval, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(out_year_interval, outputs.at(3)); +} + +TEST_F(TestProjector, TestIntCastFunction) { + // input fields + auto field0 = field("f0", arrow::float32()); + auto field1 = field("f1", arrow::float64()); + auto field2 = field("f2", arrow::month_interval()); + auto schema = arrow::schema({field0, field1, field2}); + + // output fields + auto res_int32 = field("res", arrow::int32()); + + // Build expression + auto cast_expr_float4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int32); + auto cast_expr_float8 = TreeExprBuilder::MakeExpression("castINT", {field1}, res_int32); + auto cast_expr_year_interval = + TreeExprBuilder::MakeExpression("castINT", {field2}, res_int32); + + std::shared_ptr projector; + + // {cast_expr_float4, cast_expr_float8, cast_expr_day_interval, + // cast_expr_year_interval} + auto status = Projector::Make( + schema, {cast_expr_float4, cast_expr_float8, cast_expr_year_interval}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = + MakeArrowArrayFloat32({6.6f, -6.6f, 9.999999f, 0}, {true, true, true, false}); + auto array1 = + MakeArrowArrayFloat64({6.6, -6.6, 9.99999999999, 0}, {true, true, true, false}); + auto array2 = MakeArrowArrayInt32({25, -25, -0, 0}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); + + auto out_float4 = MakeArrowArrayInt32({7, -7, 10, 0}, {true, true, true, false}); + auto out_float8 = MakeArrowArrayInt32({7, -7, 10, 0}, {true, true, true, false}); + auto out_year_interval = MakeArrowArrayInt32({2, -2, 0, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(out_year_interval, outputs.at(2)); +} + +TEST_F(TestProjector, TestCastNullableIntYearInterval) { + // input fields + auto field1 = field("f1", arrow::month_interval()); + auto schema = arrow::schema({field1}); + + // output fields + auto res_int32 = field("res", arrow::int32()); + auto res_int64 = field("res", arrow::int64()); + + // Build expression + auto cast_expr_int32 = + TreeExprBuilder::MakeExpression("castNULLABLEINT", {field1}, res_int32); + auto cast_expr_int64 = + TreeExprBuilder::MakeExpression("castNULLABLEBIGINT", {field1}, res_int64); + + std::shared_ptr projector; + + // {cast_expr_int32, cast_expr_int64, cast_expr_day_interval, + // cast_expr_year_interval} + auto status = Projector::Make(schema, {cast_expr_int32, cast_expr_int64}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = MakeArrowArrayInt32({12, -24, -0, 0}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + auto out_int32 = MakeArrowArrayInt32({1, -2, -0, 0}, {true, true, true, false}); + auto out_int64 = MakeArrowArrayInt64({1, -2, -0, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_int32, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_int64, outputs.at(1)); +} + } // namespace gandiva