Skip to content

Commit

Permalink
Add castINT and castBIGINT from dayinterval and yearinterval
Browse files Browse the repository at this point in the history
  • Loading branch information
jvictorhuguenin committed Jul 16, 2021
1 parent d98af88 commit 5c6a7f3
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 1 deletion.
1 change: 0 additions & 1 deletion cpp/src/gandiva/function_registry_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
UNARY_CAST_TO_INT32(float32), UNARY_CAST_TO_INT32(float64),

// cast to int64

UNARY_CAST_TO_INT64(float32), UNARY_CAST_TO_INT64(float64),

// cast to float64
Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/function_registry_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ using arrow::boolean;
using arrow::date32;
using arrow::date64;
using arrow::day_time_interval;
using arrow::month_interval;
using arrow::float32;
using arrow::float64;
using arrow::int16;
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/gandiva/function_registry_datetime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ std::vector<NativeFunction> GetDateTimeFunctionRegistry() {
NativeFunction("castBIGINT", {}, DataTypeVector{day_time_interval()}, int64(),
kResultNullIfNull, "castBIGINT_daytimeinterval"),

NativeFunction("castINT", {}, DataTypeVector{day_time_interval()}, int32(),
kResultNullIfNull, "castINT_dayinterval", NativeFunction::kCanReturnErrors),

NativeFunction("castBIGINT", {}, DataTypeVector{day_time_interval()}, int64(),
kResultNullIfNull, "castBIGINT_dayinterval", NativeFunction::kCanReturnErrors),

NativeFunction("castINT", {}, DataTypeVector{month_interval()}, int32(),
kResultNullIfNull, "castINT_year_interval", NativeFunction::kCanReturnErrors),

NativeFunction("castBIGINT", {}, DataTypeVector{month_interval()}, int64(),
kResultNullIfNull, "castBIGINT_year_interval", NativeFunction::kCanReturnErrors),

NativeFunction("extractDay", {}, DataTypeVector{day_time_interval()}, int64(),
kResultNullIfNull, "extractDay_daytimeinterval"),

Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/gdv_function_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ using gdv_timestamp = int64_t;
using gdv_utf8 = char*;
using gdv_binary = char*;
using gdv_day_time_interval = int64_t;
using gdv_year_interval = int64_t;

#ifdef GANDIVA_UNIT_TEST
// unit tests may be compiled without O2, so inlining may not happen.
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/gandiva/precompiled/time.cc
Original file line number Diff line number Diff line change
Expand Up @@ -860,4 +860,23 @@ NUMERIC_TYPES(TO_TIMESTAMP)

NUMERIC_TYPES(TO_TIME)

#define CAST_INT_DAY_INTERVAL(NAME, OUT_TYPE) \
gdv_##OUT_TYPE NAME##_dayinterval(gdv_day_time_interval in){ \
return static_cast<gdv_##OUT_TYPE>(in & 0x00000000FFFFFFFF);\
}\

CAST_INT_DAY_INTERVAL(castBIGINT, int64)
CAST_INT_DAY_INTERVAL(castINT, int32)
#undef CAST_INT_DAY_INTERVAL

#define CAST_INT_YEAR_INTERVAL(NAME, OUT_TYPE) \
gdv_##OUT_TYPE NAME##_year_interval(gdv_year_interval in){ \
return static_cast<gdv_##OUT_TYPE>(in/12.0); \
} \

CAST_INT_YEAR_INTERVAL(castBIGINT, int64)
CAST_INT_YEAR_INTERVAL(castINT, int32)

#undef CAST_INT_YEAR_INTERVAL

} // extern "C"
20 changes: 20 additions & 0 deletions cpp/src/gandiva/precompiled/time_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -869,4 +869,24 @@ TEST(TestTime, TestToTimeNumeric) {
EXPECT_EQ(expected_output, to_time_float64(3601.500));
}

TEST(TestTime, TestCastIntDayInterval){
EXPECT_EQ(castINT_dayinterval(10),10);
EXPECT_EQ(castINT_dayinterval(-100),-100);
EXPECT_EQ(castINT_dayinterval(-0),0);

EXPECT_EQ(castBIGINT_dayinterval(10),10);
EXPECT_EQ(castBIGINT_dayinterval(-100),-100);
EXPECT_EQ(castBIGINT_dayinterval(-0),0);
}

TEST(TestTime, TestCastIntYearInterval){
EXPECT_EQ(castINT_year_interval(24),2);
EXPECT_EQ(castINT_year_interval(-24),-2);
EXPECT_EQ(castINT_year_interval(-23),-1);

EXPECT_EQ(castBIGINT_year_interval(24),2);
EXPECT_EQ(castBIGINT_year_interval(-24),-2);
EXPECT_EQ(castBIGINT_year_interval(-23),-1);
}

} // namespace gandiva
10 changes: 10 additions & 0 deletions cpp/src/gandiva/precompiled/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,4 +551,14 @@ int64_t castBIGINT_float32(gdv_float32 value);

int64_t castBIGINT_float64(gdv_float64 value);

int64_t castBIGINT_daytimeinterval(gdv_day_time_interval in);

int32_t castINT_dayinterval(gdv_day_time_interval in);

int64_t castBIGINT_dayinterval(gdv_day_time_interval in);

int32_t castINT_year_interval(gdv_year_interval in);

int64_t castBIGINT_year_interval(gdv_year_interval in);

} // extern "C"
55 changes: 55 additions & 0 deletions cpp/src/gandiva/tests/projector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1353,4 +1353,59 @@ TEST_F(TestProjector, TestBinRepresentation) {
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}

TEST_F(TestProjector, TestBigIntCastFunction) {
// input fields
auto field0 = field("f0", arrow::float32());
auto field1 = field("f1", arrow::float64());
auto field2 = field("f2", arrow::day_time_interval());
auto field3 = field("f3", arrow::month_interval());
auto schema = arrow::schema({field0, field1, field2, field3});

// output fields
auto res_int64 = field("res", arrow::int64());

// Build expression
auto cast_expr_float4 =
TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int64);
auto cast_expr_float8 =
TreeExprBuilder::MakeExpression("castBIGINT", {field1}, res_int64);
auto cast_expr_day_interval = TreeExprBuilder::MakeExpression("castBIGINT", {field2}, res_int64);
auto cast_expr_year_interval = TreeExprBuilder::MakeExpression("castBIGINT", {field3}, res_int64);

std::shared_ptr<Projector> projector;

// {cast_expr_float4, cast_expr_float8, cast_expr_day_interval, cast_expr_year_interval}
auto status = Projector::Make(
schema, {cast_expr_float4, cast_expr_float8, cast_expr_day_interval, cast_expr_year_interval},
TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());

// Create a row-batch with some sample data
int num_records = 4;

// Last validity is false and the cast functions throw error when input is empty. Should
// not be evaluated due to addition of NativeFunction::kCanReturnErrors
auto array0 = MakeArrowArrayUtf8({"6.6", "-6,6", "9.999999", ""}, {true, true, true, false});
auto array1 = MakeArrowArrayUtf8({"6.6", "-6.6", "9.99999999999", ""}, {true, true, true, false});
auto array2 = MakeArrowArrayUtf8({"100", "-25", "-0", ""}, {true, true, true, false});
auto array3 = MakeArrowArrayUtf8({"25", "-25", "-0", ""}, {true, true, true, false});
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1 ,array2, array3});

auto out_float4 = MakeArrowArrayInt64({7, 7, 10, 0}, {true, true, true, false});
auto out_float8 = MakeArrowArrayInt64({7, 7, 10, 0}, {true, true, true, false});
auto out_days_interval = MakeArrowArrayInt64({100, -25, 0, 0}, {true, true, true, false});
auto out_year_interval = MakeArrowArrayInt64({2, -2, 0, 0}, {true, true, true, false});

arrow::ArrayVector outputs;

// Evaluate expression
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());

EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0));
EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1));
EXPECT_ARROW_ARRAY_EQUALS(out_days_interval, outputs.at(2));
EXPECT_ARROW_ARRAY_EQUALS(out_days_interval, outputs.at(3));
}

} // namespace gandiva

0 comments on commit 5c6a7f3

Please sign in to comment.