feat: Implement Spark-compatible CAST from string to integral types #307
Conversation
I am now working on refactoring to reduce code duplication by leveraging macros/generics.
(
    DataType::Dictionary(key_type, value_type),
    DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
) if key_type.as_ref() == &DataType::Int32
@viirya do you know if dictionary keys will always be i32?
I've been assuming it to be so, though @viirya can give us the definitive answer
I remember that in many places in the native code, we assume that dictionary keys are always of Int32 type. But I forget where we make that assumption. 😅
cc @sunchao Do you remember?
Oh, I see. I think the assumption comes from the native scan side, where the Parquet dictionary indices are always of integer type, so dictionary keys read from native scan are always of Int32 type.
You can check the `DictDecoder` in the native scan implementation.
The exception would be if any operator or expression during execution produces a dictionary with keys of a type other than Int32. But in that case I think it should be considered a bug for us to fix, because I don't think it makes sense to change the dictionary key type.
> Oh, I see. I think the assumption comes from the native scan side, where the Parquet dictionary indices are always of integer type, so dictionary keys read from native scan are always of Int32 type.

Yes, that is exactly right.
Thanks @sunchao for confirming it.
@@ -64,6 +68,25 @@ pub struct Cast {
    pub timezone: String,
}

macro_rules! spark_cast_utf8_to_integral {
Maybe `utf8_to_integer`? Spark is not involved in native exec, so I'm not sure why the `spark` prefix is needed. Also, integral types include booleans, and this scope is limited to integers AFAIK.
macro_rules! spark_cast_utf8_to_integral {
    ($string_array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        let mut cast_array = PrimitiveArray::<$array_type>::builder($string_array.len());
        for i in 0..$string_array.len() {
We could probably use an iterator instead of a for loop? Also, let's compute `$string_array.len()` once.
I think `$string_array.len()` is already only computed once?
I see them on lines 73,74 🤔
I missed that! Thanks
Fixed
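For reference, the iterator style the reviewer suggests could look roughly like this in plain Rust. This is only a sketch: `parse_opt` and `cast_iter` are hypothetical names, and a slice of `Option<&str>` stands in for the Arrow string array.

```rust
// Sketch only: `parse_opt` stands in for the PR's per-value cast function.
fn parse_opt(s: Option<&str>) -> Result<Option<i32>, String> {
    match s {
        None => Ok(None), // null in, null out
        Some(v) => v.trim().parse::<i32>().map(Some).map_err(|e| e.to_string()),
    }
}

// Iterator form: map each slot and collect, letting `collect`
// short-circuit on the first error instead of indexing 0..len.
fn cast_iter(input: &[Option<&str>]) -> Result<Vec<Option<i32>>, String> {
    input.iter().map(|s| parse_opt(*s)).collect()
}
```

One advantage of collecting from an iterator is that the error path needs no manual loop bookkeeping; the trade-off is that builder-based code can be easier to map onto Arrow's `PrimitiveBuilder` API.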
    Ok(spark_cast(cast_result, from_type, to_type))
}

fn spark_cast_string_to_integral(
`string_to_int`?
Thanks @andygrove. BTW, I'm wondering if this PR should also cover number format patterns: https://spark.apache.org/docs/latest/sql-ref-number-pattern.html#the-to_number-function
Co-authored-by: comphead <comphead@users.noreply.github.com>
…datafusion-comet into cast-string-to-integral
}

ignore("cast string to short") {
test("cast string to short") {
We should probably have some negative tests with invalid strings. Also, curious: what does `cast(".")` yield?
The fuzz testing does generate many invalid inputs. I can add some more explicit ones to these tests, though.

`cast(".")` will yield different results depending on the eval mode:

- LEGACY -> 0
- TRY -> null
- ANSI -> error
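The three behaviors described above can be sketched with a simplified `EvalMode` enum and a hypothetical `cast_str_to_i32` helper. This is not the PR's actual API, and legacy parsing is reduced here to "fall back to 0", which only models the `"."` example, not Spark's full legacy semantics.

```rust
#[derive(Debug, Clone, Copy)]
enum EvalMode {
    Legacy,
    Try,
    Ansi,
}

// Sketch: dispatch on the eval mode when the parse fails.
fn cast_str_to_i32(s: &str, mode: EvalMode) -> Result<Option<i32>, String> {
    match s.trim().parse::<i32>() {
        Ok(v) => Ok(Some(v)),
        Err(_) => match mode {
            EvalMode::Legacy => Ok(Some(0)), // simplified legacy fallback
            EvalMode::Try => Ok(None),       // invalid input becomes null
            EvalMode::Ansi => Err(format!("invalid input syntax for type INT: '{s}'")),
        },
    }
}
```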
+1 for these tests btw
Sorry, I'm not sure I understand. Are you referring to the error message formatting?
Oh, it covers just casting string to integers, I thought.
let negative = chars[0] == '-';
if negative || chars[0] == '+' {
This seems wrong. It should be `chars[i] == '-'` instead? Otherwise, this cast doesn't work for `-124`.
Thank you! The code was originally trimming the string before this point and I missed updating this when I removed the trim. I have now fixed this.
use super::{cast_string_to_i8, EvalMode};

#[test]
fn test_cast_string_as_i8() {
How about adding more tests for `i32` and `i64` with their min/max and zero inputs?
I am going to focus on improving the tests in this PR today
I have now added tests for all min/max boundary values in the Scala tests
@@ -103,10 +125,72 @@ impl Cast {
    (DataType::LargeUtf8, DataType::Boolean) => {
        Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
Not part of this PR, but if we are going to name the added method `cast_string_to_int`, should this method be renamed to `cast_utf8_to_boolean` as well in a follow-up PR?
I agree. I didn't want to start making unrelated changes in this PR, but we should rename this.
// Note that we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
I think we can leave a TODO to cast dictionary directly?
ignore("cast string to long") {
    castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType)

private val castStringToIntegralInputs: Seq[String] = Seq(
nit: Since the cast code handles leading and trailing whitespace, I think we can add more inputs with whitespace.
For example:
castStringToIntegralInputs.flatMap { x => Seq(" " + x, x + " ", " " + x + " ") }
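The same expansion could be sketched in Rust terms for the native-side unit tests. `with_whitespace_variants` is a hypothetical helper mirroring the reviewer's Scala `flatMap` suggestion.

```rust
// Sketch: expand each test input into leading/trailing/both-padded variants.
fn with_whitespace_variants(inputs: &[&str]) -> Vec<String> {
    inputs
        .iter()
        .flat_map(|x| {
            vec![
                x.to_string(),
                format!(" {x}"),
                format!("{x} "),
                format!(" {x} "),
            ]
        })
        .collect()
}
```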
Thanks for your effort @andygrove, the new code is well crafted.
@viirya @sunchao @parthchandra @comphead I did quite a bit of refactoring and performance tuning over the weekend. Please take another look when you can.
Thank you for the thorough review @advancedxy!
let len = $array.len();
let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
for i in 0..len {
    if $array.is_null(i) {
maybe it can be simplified to
if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
cast_array.append_value(cast_value);
} else {
cast_array.append_null()
}
If there is a null input then we will always want a null output and we don't want to add the overhead of calling the cast logic in this case.
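The null short-circuit being described can be sketched in plain Rust. These are hypothetical names: `parse_i32` stands in for the PR's cast method, and a slice of `Option<&str>` stands in for the Arrow array with its null buffer.

```rust
// Stand-in for the fallible, relatively expensive per-value cast.
fn parse_i32(s: &str) -> Result<i32, String> {
    s.trim().parse::<i32>().map_err(|e| e.to_string())
}

// For a null slot we append null directly and never invoke the cast,
// avoiding the overhead the author mentions.
fn cast_all(input: &[Option<&str>]) -> Result<Vec<Option<i32>>, String> {
    let mut out = Vec::with_capacity(input.len());
    for slot in input {
        match slot {
            None => out.push(None), // null in => null out, no cast call
            Some(s) => out.push(Some(parse_i32(s)?)),
        }
    }
    Ok(out)
}
```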
LGTM, thanks @andygrove. A couple of minor comments.
/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<Option<T>> {
    match eval_mode {
        EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
        _ => Ok(None),
    }
}

fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
    CometError::CastInvalidValue {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}
I think these can be `inline` functions?
Thanks. I have updated this.
fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
}

fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
}
Why do only i8 and i16 have a range check?
The code is ported directly from Spark. This is the approach that is used there.
Spark has `IntWrapper` and `LongWrapper`, which are equivalent to `do_cast_string_to_int::<i32>` and `do_cast_string_to_int::<i64>` in this PR.

This is the logic for casting to byte in Spark. It uses `IntWrapper` and then casts to `byte`:
public boolean toByte(IntWrapper intWrapper) {
    if (toInt(intWrapper)) {
        int intValue = intWrapper.value;
        byte result = (byte) intValue;
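The same parse-wide-then-narrow check can be sketched in Rust. `to_byte` here is a hypothetical standalone function, not the PR's code: parse into `i32`, narrow with `as i8`, and accept the value only if the narrowing round-trips.

```rust
// Sketch of Spark's toByte approach: the narrowing cast round-trips
// exactly when the parsed value fits in i8, so no separate range
// constants are needed.
fn to_byte(s: &str) -> Option<i8> {
    let int_value: i32 = s.trim().parse().ok()?;
    let result = int_value as i8; // truncating narrow, like Java's (byte) cast
    if result as i32 == int_value {
        Some(result)
    } else {
        None // out of i8 range
    }
}
```

This is why i8 and i16 get an explicit check while i32 and i64 do not: the wider parse already enforces the i32/i64 bounds.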
I pushed a commit to add some comments referencing the Spark code that this code is based on
@viirya Let me know if there is anything else to address. I have upmerged with latest from main branch so this PR is a little smaller now.
I will take another look at this tonight.
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE /
// radix), we can just use `result > 0` to check overflow. If result
// overflows, we should stop
Do you mean "more than or equal to"? I think the above condition (L352) is already for `result < stop_value`?
The comment was copied from the Spark code in `org/apache/spark/unsafe/types/UTF8String.java`, but I agree that it seems incorrect. I have updated it.
Co-authored-by: comphead <comphead@users.noreply.github.com>
Co-authored-by: comphead <comphead@users.noreply.github.com>
Looks good to me. Thanks @andygrove
LGTM
    return none_or_err(eval_mode, type_name, str);
};

// We are going to process the new digit and accumulate the result. However, before
A comment to explain why we're using subtraction instead of addition would make it easier to understand this part of the code.
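For illustration, the subtraction trick can be sketched as a hypothetical standalone function (not the PR's `do_cast_string_to_int`): accumulating the magnitude as a negative number makes `i32::MIN` representable, since `|i32::MIN|` exceeds `i32::MAX` and could not be built up with addition.

```rust
// Sketch: accumulate negatively so "-2147483648" (i32::MIN) parses, and
// detect overflow with a stop-value check plus a sign check.
fn parse_i32_negative_accum(s: &str) -> Option<i32> {
    let bytes = s.trim().as_bytes();
    let (negative, digits) = match *bytes.first()? {
        b'-' => (true, &bytes[1..]),
        b'+' => (false, &bytes[1..]),
        _ => (false, &bytes[..]),
    };
    if digits.is_empty() {
        return None;
    }
    let stop = i32::MIN / 10; // any result below this overflows when scaled
    let mut result: i32 = 0;
    for &b in digits {
        if !b.is_ascii_digit() {
            return None;
        }
        if result < stop {
            return None; // result * 10 would underflow i32
        }
        // Wrapping sub: the only possible wrap is past i32::MIN, which
        // flips the sign, so `result > 0` detects it.
        result = (result * 10).wrapping_sub((b - b'0') as i32);
        if result > 0 {
            return None;
        }
    }
    if negative {
        Some(result) // already holds the negative value
    } else {
        result.checked_neg() // rejects "2147483648": -i32::MIN doesn't exist
    }
}
```

The final `checked_neg` is where the asymmetry of two's complement surfaces: `-2147483648` is accepted, while `2147483648` is rejected.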
Which issue does this PR close?
Part of #286
Closes #15
Rationale for this change
Improve compatibility with Apache Spark
What changes are included in this PR?
Add custom implementation of CAST from string to integral rather than delegate to DataFusion
How are these changes tested?