From b5e3e6c50c522494e324fb98535112b8ecdaa7a9 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 26 Aug 2024 10:56:25 -0600
Subject: [PATCH 1/5] basic version of string to float/double/decimal

---
 docs/source/user-guide/compatibility.md       |   2 +
 .../core/src/execution/datafusion/planner.rs  |  12 +-
 native/proto/src/proto/expr.proto             |   4 +-
 native/spark-expr/src/cast.rs                 |  43 +++++-
 .../apache/comet/expressions/CometCast.scala  |   6 +-
 .../apache/comet/serde/QueryPlanSerde.scala   |   3 +-
 .../org/apache/comet/CometCastSuite.scala     | 138 +++++++++++-------
 7 files changed, 146 insertions(+), 62 deletions(-)

diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md
index 0af44eb62..739ab6342 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -136,6 +136,8 @@ The following cast operations are not compatible with Spark for all inputs and a
 |-|-|-|
 | integer | decimal | No overflow check |
 | long | decimal | No overflow check |
+| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. |
+| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. |
 | string | timestamp | Not all valid formats are supported |
 | binary | string | Only works for binary data representing valid UTF-8 strings |
 
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index b0137bf85..aee0471a5 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -382,8 +382,13 @@ impl PhysicalPlanner {
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 let timezone = expr.timezone.clone();
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
-
-                Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone)))
+                Ok(Arc::new(Cast::new(
+                    child,
+                    datatype,
+                    eval_mode,
+                    timezone,
+                    expr.allow_incompat,
+                )))
             }
             ExprStruct::Hour(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
@@ -723,17 +728,20 @@ impl PhysicalPlanner {
                     left,
                     DataType::Decimal256(p1, s1),
                     EvalMode::Legacy,
+                    false,
                 ));
                 let right = Arc::new(Cast::new_without_timezone(
                     right,
                     DataType::Decimal256(p2, s2),
                     EvalMode::Legacy,
+                    false,
                 ));
                 let child = Arc::new(BinaryExpr::new(left, op, right));
                 Ok(Arc::new(Cast::new_without_timezone(
                     child,
                     data_type,
                     EvalMode::Legacy,
+                    false,
                 )))
             }
             (
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 50ab8f514..fa8f79ace 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -254,8 +254,8 @@ message Cast {
   Expr child = 1;
   DataType datatype = 2;
   string timezone = 3;
-  EvalMode eval_mode = 4;
-
+  EvalMode eval_mode = 4;
+  bool allow_incompat = 5;
 }
 
 message Equal {
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index d8c03d897..dfe64a008 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -142,6 +142,8 @@ pub struct Cast {
     /// When cast from/to timezone related types, we need timezone, which will be resolved with
     /// session local timezone by an analyzer in Spark.
    pub timezone: String,
+
+    pub allow_incompat: bool,
 }
 
 macro_rules! cast_utf8_to_int {
@@ -545,12 +547,14 @@ impl Cast {
         data_type: DataType,
         eval_mode: EvalMode,
         timezone: String,
+        allow_incompat: bool,
     ) -> Self {
         Self {
             child,
             data_type,
             timezone,
             eval_mode,
+            allow_incompat,
         }
     }
 
@@ -558,12 +562,14 @@ impl Cast {
         child: Arc<dyn PhysicalExpr>,
         data_type: DataType,
         eval_mode: EvalMode,
+        allow_incompat: bool,
     ) -> Self {
         Self {
             child,
             data_type,
             timezone: "".to_string(),
             eval_mode,
+            allow_incompat,
         }
     }
 }
@@ -576,6 +582,7 @@ pub fn spark_cast(
     data_type: &DataType,
     eval_mode: EvalMode,
     timezone: String,
+    allow_incompat: bool,
 ) -> DataFusionResult<ColumnarValue> {
     match arg {
         ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
@@ -583,6 +590,7 @@ pub fn spark_cast(
             data_type,
             eval_mode,
             timezone.to_owned(),
+            allow_incompat,
         )?)),
         ColumnarValue::Scalar(scalar) => {
             // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for
@@ -590,7 +598,13 @@ pub fn spark_cast(
             // here.
             let array = scalar.to_array()?;
             let scalar = ScalarValue::try_from_array(
-                &cast_array(array, data_type, eval_mode, timezone.to_owned())?,
+                &cast_array(
+                    array,
+                    data_type,
+                    eval_mode,
+                    timezone.to_owned(),
+                    allow_incompat,
+                )?,
                 0,
             )?;
             Ok(ColumnarValue::Scalar(scalar))
@@ -603,6 +617,7 @@ fn cast_array(
     array: ArrayRef,
     to_type: &DataType,
     eval_mode: EvalMode,
     timezone: String,
+    allow_incompat: bool,
 ) -> DataFusionResult<ArrayRef> {
     let array = array_with_timezone(array, timezone.clone(), Some(to_type))?;
     let from_type = array.data_type().clone();
@@ -624,6 +639,7 @@ fn cast_array(
             to_type,
             eval_mode,
             timezone,
+            allow_incompat,
         )?,
     );
 
@@ -693,7 +709,7 @@ fn cast_array(
         {
             spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
         }
-        _ if is_datafusion_spark_compatible(from_type, to_type) => {
+        _ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
         }
@@ -711,7 +727,11 @@
 
 /// Determines if DataFusion supports the given cast in a way that is
 /// compatible with Spark
-fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
+fn is_datafusion_spark_compatible(
+    from_type: &DataType,
+    to_type: &DataType,
+    allow_incompat: bool,
+) -> bool {
     if from_type == to_type {
         return true;
     }
@@ -764,6 +784,10 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
                 | DataType::Decimal128(_, _)
                 | DataType::Decimal256(_, _)
         ),
+        DataType::Utf8 if allow_incompat => matches!(
+            to_type,
+            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+        ),
         DataType::Utf8 => matches!(to_type, DataType::Binary),
         DataType::Date32 => matches!(to_type, DataType::Utf8),
         DataType::Timestamp(_, _) => {
@@ -1385,7 +1409,13 @@ impl PhysicalExpr for Cast {
 
     fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
         let arg = self.child.evaluate(batch)?;
-        spark_cast(arg, &self.data_type, self.eval_mode, self.timezone.clone())
+        spark_cast(
+            arg,
+            &self.data_type,
+            self.eval_mode,
+            self.timezone.clone(),
+            self.allow_incompat,
+        )
     }
 
     fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
@@ -1402,6 +1432,7 @@ impl PhysicalExpr for Cast {
                 self.data_type.clone(),
                 self.eval_mode,
                 self.timezone.clone(),
+                self.allow_incompat,
             ))),
             _ => internal_err!("Cast should have exactly one child"),
         }
     }
@@ -1413,6 +1444,7 @@ impl PhysicalExpr for Cast {
         self.data_type.hash(&mut s);
         self.timezone.hash(&mut s);
         self.eval_mode.hash(&mut s);
+        self.allow_incompat.hash(&mut s);
         self.hash(&mut s);
     }
 }
@@ -1996,6 +2028,7 @@ mod tests {
         &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
         EvalMode::Legacy,
         timezone.clone(),
+        false,
     )?;
     assert_eq!(
         *result.data_type(),
@@ -2205,6 +2238,7 @@ mod tests {
             &DataType::Date32,
             EvalMode::Legacy,
             "UTC".to_owned(),
+            false,
         );
         assert!(result.is_err())
     }
@@ -2217,6 +2251,7 @@ mod tests {
             &DataType::Date32,
             EvalMode::Legacy,
             "Not a valid timezone".to_owned(),
+            false,
         );
         assert!(result.is_err())
     }
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 811c61d46..139a2a1b5 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -113,10 +113,12 @@ object CometCast {
       Compatible()
     case DataTypes.FloatType | DataTypes.DoubleType =>
       // https://github.com/apache/datafusion-comet/issues/326
-      Unsupported
+      Incompatible(Some(
+        "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode."))
     case _: DecimalType =>
       // https://github.com/apache/datafusion-comet/issues/325
-      Unsupported
+      Incompatible(Some(
+        "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
     case DataTypes.DateType =>
       // https://github.com/apache/datafusion-comet/issues/327
       Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index cfb847644..45f9d1992 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -792,7 +792,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       castBuilder.setChild(childExpr.get)
       castBuilder.setDatatype(dataType.get)
       castBuilder.setEvalMode(evalModeToProto(evalMode))
-
+      castBuilder.setAllowIncompat(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get())
       val timeZone = timeZoneId.getOrElse("UTC")
       castBuilder.setTimezone(timeZone)
 
@@ -1506,6 +1506,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             .setChild(e)
             .setDatatype(serializeDataType(IntegerType).get)
             .setEvalMode(ExprOuterClass.EvalMode.LEGACY)
+            .setAllowIncompat(false)
             .build())
         .build()
     })
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 833b77d5d..4bcc3e00b 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -555,17 +555,51 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.FloatType)
   }
 
+  test("cast StringType to FloatType (partial support)") {
+    withSQLConf(
+      CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true",
+      SQLConf.ANSI_ENABLED.key -> "false") {
+      castTest(
+        gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
+        DataTypes.FloatType,
+        testAnsi = false)
+    }
+  }
+
   ignore("cast StringType to DoubleType") {
     // https://github.com/apache/datafusion-comet/issues/326
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
   }
 
+  test("cast StringType to DoubleType (partial support)") {
+    withSQLConf(
+      CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true",
+      SQLConf.ANSI_ENABLED.key -> "false") {
+      castTest(
+        gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
+        DataTypes.DoubleType,
+        testAnsi = false)
+    }
+  }
+
   ignore("cast StringType to DecimalType(10,2)") {
     // https://github.com/apache/datafusion-comet/issues/325
     val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
     castTest(values, DataTypes.createDecimalType(10, 2))
   }
 
+  test("cast StringType to DecimalType(10,2) (partial support)") {
+    withSQLConf(
+      CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true",
+      SQLConf.ANSI_ENABLED.key -> "false") {
+      val values = gen
+        .generateStrings(dataSize, "0123456789.", 8)
+        .filter(_.exists(_.isDigit))
+        .toDF("a")
+      castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+    }
+  }
+
   test("cast StringType to BinaryType") {
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.BinaryType)
   }
 
@@ -963,7 +997,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  private def castTest(input: DataFrame, toType: DataType): Unit = {
+  private def castTest(input: DataFrame, toType: DataType, testAnsi: Boolean = true): Unit = {
     // we now support the TryCast expression in Spark 3.3
     withTempPath { dir =>
@@ -981,60 +1015,62 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       checkSparkAnswerAndOperator(df2)
     }
 
-    // with ANSI enabled, we should produce the same exception as Spark
-    withSQLConf(
-      (SQLConf.ANSI_ENABLED.key, "true"),
-      (CometConf.COMET_ANSI_MODE_ENABLED.key, "true")) {
-
-      // cast() should throw exception on invalid inputs when ansi mode is enabled
-      val df = data.withColumn("converted", col("a").cast(toType))
-      checkSparkMaybeThrows(df) match {
-        case (None, None) =>
-        // neither system threw an exception
-        case (None, Some(e)) =>
-          // Spark succeeded but Comet failed
-          throw e
-        case (Some(e), None) =>
-          // Spark failed but Comet succeeded
-          fail(s"Comet should have failed with ${e.getCause.getMessage}")
-        case (Some(sparkException), Some(cometException)) =>
-          // both systems threw an exception so we make sure they are the same
-          val sparkMessage =
-            if (sparkException.getCause != null) sparkException.getCause.getMessage
-            else sparkException.getMessage
-          val cometMessage =
-            if (cometException.getCause != null) cometException.getCause.getMessage
-            else cometException.getMessage
-          if (CometSparkSessionExtensions.isSpark40Plus) {
-            // for Spark 4 we expect to sparkException carries the message
-            assert(
-              sparkException.getMessage
-                .replace(".WITH_SUGGESTION] ", "]")
-                .startsWith(cometMessage))
-          } else if (CometSparkSessionExtensions.isSpark34Plus) {
-            // for Spark 3.4 we expect to reproduce the error message exactly
-            assert(cometMessage == sparkMessage)
-          } else {
-            // for Spark 3.3 we just need to strip the prefix from the Comet message
-            // before comparing
-            val cometMessageModified = cometMessage
-              .replace("[CAST_INVALID_INPUT] ", "")
-              .replace("[CAST_OVERFLOW] ", "")
-              .replace("[NUMERIC_VALUE_OUT_OF_RANGE] ", "")
-
-            if (sparkMessage.contains("cannot be represented as")) {
-              assert(cometMessage.contains("cannot be represented as"))
+    if (testAnsi) {
+      // with ANSI enabled, we should produce the same exception as Spark
+      withSQLConf(
+        (SQLConf.ANSI_ENABLED.key, "true"),
+        (CometConf.COMET_ANSI_MODE_ENABLED.key, "true")) {
+
+        // cast() should throw exception on invalid inputs when ansi mode is enabled
+        val df = data.withColumn("converted", col("a").cast(toType))
+        checkSparkMaybeThrows(df) match {
+          case (None, None) =>
+          // neither system threw an exception
+          case (None, Some(e)) =>
+            // Spark succeeded but Comet failed
+            throw e
+          case (Some(e), None) =>
+            // Spark failed but Comet succeeded
+            fail(s"Comet should have failed with ${e.getCause.getMessage}")
+          case (Some(sparkException), Some(cometException)) =>
+            // both systems threw an exception so we make sure they are the same
+            val sparkMessage =
+              if (sparkException.getCause != null) sparkException.getCause.getMessage
+              else sparkException.getMessage
+            val cometMessage =
+              if (cometException.getCause != null) cometException.getCause.getMessage
+              else cometException.getMessage
+            if (CometSparkSessionExtensions.isSpark40Plus) {
+              // for Spark 4 we expect the sparkException to carry the message
+              assert(
+                sparkException.getMessage
+                  .replace(".WITH_SUGGESTION] ", "]")
+                  .startsWith(cometMessage))
+            } else if (CometSparkSessionExtensions.isSpark34Plus) {
+              // for Spark 3.4 we expect to reproduce the error message exactly
+              assert(cometMessage == sparkMessage)
             } else {
-              assert(cometMessageModified == sparkMessage)
+              // for Spark 3.3 we just need to strip the prefix from the Comet message
+              // before comparing
+              val cometMessageModified = cometMessage
+                .replace("[CAST_INVALID_INPUT] ", "")
+                .replace("[CAST_OVERFLOW] ", "")
+                .replace("[NUMERIC_VALUE_OUT_OF_RANGE] ", "")
+
+              if (sparkMessage.contains("cannot be represented as")) {
+                assert(cometMessage.contains("cannot be represented as"))
+              } else {
+                assert(cometMessageModified == sparkMessage)
+              }
             }
-          }
-      }
+        }
 
-      // try_cast() should always return null for invalid inputs
-      val df2 =
-        spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-      checkSparkAnswerAndOperator(df2)
+        // try_cast() should always return null for invalid inputs
+        val df2 =
+          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+        checkSparkAnswerAndOperator(df2)
+      }
     }
   }
 }

From a1205bde487695b8cc8213e2bdc3734439d9bf96 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 26 Aug 2024 11:12:31 -0600
Subject: [PATCH 2/5] docs

---
 docs/source/user-guide/compatibility.md           |  5 +++--
 .../org/apache/comet/expressions/CometCast.scala  | 11 +++++++----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md
index 739ab6342..5e8499928 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -136,8 +136,9 @@ The following cast operations are not compatible with Spark for all inputs and a
 |-|-|-|
 | integer | decimal | No overflow check |
 | long | decimal | No overflow check |
-| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. |
-| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. |
+| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
+| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
+| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
 | string | timestamp | Not all valid formats are supported |
 | binary | string | Only works for binary data representing valid UTF-8 strings |
 
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 139a2a1b5..9725c5f21 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -113,12 +113,15 @@ object CometCast {
       Compatible()
     case DataTypes.FloatType | DataTypes.DoubleType =>
       // https://github.com/apache/datafusion-comet/issues/326
-      Incompatible(Some(
-        "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode."))
+      Incompatible(
+        Some(
+          "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
+            "Does not support ANSI mode."))
     case _: DecimalType =>
       // https://github.com/apache/datafusion-comet/issues/325
-      Incompatible(Some(
-        "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
+      Incompatible(
+        Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
+          "Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
     case DataTypes.DateType =>
       // https://github.com/apache/datafusion-comet/issues/327
       Compatible(Some("Only supports years between 262143 BC and 262142 AD"))

From dac590485eef31d8174f911b91cbd6998cc0a271 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 26 Aug 2024 11:24:59 -0600
Subject: [PATCH 3/5] update benches

---
 native/spark-expr/benches/cast_from_string.rs | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/native/spark-expr/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs
index 51410a68a..056ada2eb 100644
--- a/native/spark-expr/benches/cast_from_string.rs
+++ b/native/spark-expr/benches/cast_from_string.rs
@@ -31,20 +31,23 @@ fn criterion_benchmark(c: &mut Criterion) {
         DataType::Int8,
         EvalMode::Legacy,
         timezone.clone(),
+        false,
     );
     let cast_string_to_i16 = Cast::new(
         expr.clone(),
         DataType::Int16,
         EvalMode::Legacy,
         timezone.clone(),
+        false,
     );
     let cast_string_to_i32 = Cast::new(
         expr.clone(),
         DataType::Int32,
         EvalMode::Legacy,
         timezone.clone(),
+        false,
     );
-    let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone);
+    let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone, false);
 
     let mut group = c.benchmark_group("cast_string_to_int");
     group.bench_function("cast_string_to_i8", |b| {

From a940f1d9573104ba1b318888310d14bbf6ee00f7 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 26 Aug 2024 11:30:11 -0600
Subject: [PATCH 4/5] update benches

---
 native/spark-expr/benches/cast_numeric.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/native/spark-expr/benches/cast_numeric.rs b/native/spark-expr/benches/cast_numeric.rs
index dc0ceea79..15ef1a5a2 100644
--- a/native/spark-expr/benches/cast_numeric.rs
+++ b/native/spark-expr/benches/cast_numeric.rs
@@ -31,14 +31,16 @@ fn criterion_benchmark(c: &mut Criterion) {
         DataType::Int8,
         EvalMode::Legacy,
         timezone.clone(),
+        false,
     );
     let cast_i32_to_i16 = Cast::new(
         expr.clone(),
         DataType::Int16,
         EvalMode::Legacy,
         timezone.clone(),
+        false,
     );
-    let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone);
+    let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone, false);
 
     let mut group = c.benchmark_group("cast_int_to_int");
     group.bench_function("cast_i32_to_i8", |b| {

From 556767c87eb3273276f5183d179698f782973580 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 26 Aug 2024 12:39:09 -0600
Subject: [PATCH 5/5] rust doc

---
 native/spark-expr/src/cast.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index dfe64a008..ed8cdc2fe 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -143,6 +143,7 @@ pub struct Cast {
     /// session local timezone by an analyzer in Spark.
     pub timezone: String,
 
+    /// Whether to allow casts that are known to be incompatible with Spark
     pub allow_incompat: bool,
 }