From 43b67186554229c557c84b83146499e170dc960e Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 9 Jul 2024 19:56:28 +0900
Subject: [PATCH] [SPARK-48831][CONNECT] Make default column name of `cast`
 compatible with Spark Classic

### What changes were proposed in this pull request?

I think there are two issues regarding the default column name of `cast`:

1. It is unclear when the name is the input column name and when it is `CAST(...)`, e.g., in Spark Classic:

```
scala> spark.range(1).select(col("id").cast("string"), lit(1).cast("string"), col("id").cast("long"), lit(1).cast("long")).printSchema
warning: 1 deprecation (since 2.13.3); for details, enable `:setting -deprecation` or `:replay -deprecation`
root
 |-- id: string (nullable = false)
 |-- CAST(1 AS STRING): string (nullable = false)
 |-- id: long (nullable = false)
 |-- CAST(1 AS BIGINT): long (nullable = false)
```

2. The column name is not consistent between Spark Connect and Spark Classic.

This PR resolves the second issue, making the default column name of `cast` compatible with Spark Classic, by following the classic implementation:
https://github.com/apache/spark/blob/9cf6dc873ff34412df6256cdc7613eed40716570/sql/core/src/main/scala/org/apache/spark/sql/Column.scala#L1208-L1212
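For reference, the classic implementation at that link is roughly the following (a paraphrased sketch, not the verbatim source; `expr` and `withExpr` come from the surrounding `Column` class):

```scala
// Paraphrased sketch of Spark Classic's Column.cast (Column.scala, linked above).
// Tagging the Cast as user-specified is what drives the default name:
// non-attribute children get "CAST(expr AS TYPE)", while plain column
// references keep their own name (see the printSchema example above).
def cast(to: DataType): Column = withExpr {
  val cast = Cast(expr, CharVarcharUtils.replaceCharVarcharWithStringForCast(to))
  cast.setTagValue(Cast.USER_SPECIFIED_CAST, ())
  cast
}
```

The `transformCast` change below mirrors this on the server side: the same `CharVarcharUtils` replacement plus the same tag, so Spark Connect and Spark Classic go through the same naming path.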
### Why are the changes needed?

The default column name is not consistent with Spark Classic.

### Does this PR introduce _any_ user-facing change?

Yes. Spark Classic:

```
In [2]: spark.range(1).select(sf.lit(b'123').cast("STRING"), sf.lit(123).cast("STRING"), sf.lit(123).cast("LONG"), sf.lit(123).cast("DOUBLE")).show()
+-------------------------+-------------------+-------------------+-------------------+
|CAST(X'313233' AS STRING)|CAST(123 AS STRING)|CAST(123 AS BIGINT)|CAST(123 AS DOUBLE)|
+-------------------------+-------------------+-------------------+-------------------+
|                      123|                123|                123|              123.0|
+-------------------------+-------------------+-------------------+-------------------+
```

Spark Connect (before):

```
In [3]: spark.range(1).select(sf.lit(b'123').cast("STRING"), sf.lit(123).cast("STRING"), sf.lit(123).cast("LONG"), sf.lit(123).cast("DOUBLE")).show()
+---------+---+---+-----+
|X'313233'|123|123|  123|
+---------+---+---+-----+
|      123|123|123|123.0|
+---------+---+---+-----+
```

Spark Connect (after):

```
In [2]: spark.range(1).select(sf.lit(b'123').cast("STRING"), sf.lit(123).cast("STRING"), sf.lit(123).cast("LONG"), sf.lit(123).cast("DOUBLE")).show()
+-------------------------+-------------------+-------------------+-------------------+
|CAST(X'313233' AS STRING)|CAST(123 AS STRING)|CAST(123 AS BIGINT)|CAST(123 AS DOUBLE)|
+-------------------------+-------------------+-------------------+-------------------+
|                      123|                123|                123|              123.0|
+-------------------------+-------------------+-------------------+-------------------+
```

### How was this patch tested?

Added a test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47249 from zhengruifeng/py_fix_cast.

Authored-by: Ruifeng Zheng
Signed-off-by: Hyukjin Kwon
---
 .../explain-results/function_atan2.explain    |  2 +-
 .../explain-results/function_base64.explain   |  2 +-
 .../explain-results/function_crc32.explain    |  2 +-
 .../explain-results/function_decode.explain   |  2 +-
 .../explain-results/function_md5.explain      |  2 +-
 .../explain-results/function_sha1.explain     |  2 +-
 .../explain-results/function_sha2.explain     |  2 +-
 .../connect/planner/SparkConnectPlanner.scala | 23 +++++++++++--------
 .../planner/SparkConnectProtoSuite.scala      |  4 ++--
 .../sql/tests/connect/test_connect_column.py  | 15 ++++++++++++
 10 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_atan2.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_atan2.explain
index ebc8f138e7bd0..bf76d33519559 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_atan2.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_atan2.explain
@@ -1,2 +1,2 @@
-Project [ATAN2(cast(a#0 as double), b#0) AS ATAN2(a, b)#0]
+Project [ATAN2(cast(a#0 as double), b#0) AS ATAN2(CAST(a AS DOUBLE), b)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain
index bc3c6e4bb2bcf..f80f3522190d8 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain
@@ -1,2 +1,2 @@
-Project [base64(cast(g#0 as binary)) AS base64(g)#0]
+Project [base64(cast(g#0 as binary)) AS base64(CAST(g AS BINARY))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_crc32.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_crc32.explain
index abd5c1b135b62..3151d121a8b96 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_crc32.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_crc32.explain
@@ -1,2 +1,2 @@
-Project [crc32(cast(g#0 as binary)) AS crc32(g)#0L]
+Project [crc32(cast(g#0 as binary)) AS crc32(CAST(g AS BINARY))#0L]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_decode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_decode.explain
index c7f2e4cf9c769..ef52e6255a080 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_decode.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_decode.explain
@@ -1,2 +1,2 @@
-Project [static_invoke(StringDecode.decode(cast(g#0 as binary), UTF-8, false, false)) AS decode(g, UTF-8)#0]
+Project [static_invoke(StringDecode.decode(cast(g#0 as binary), UTF-8, false, false)) AS decode(CAST(g AS BINARY), UTF-8)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_md5.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_md5.explain
index 7bbc84785e5e8..c777010f19b06 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_md5.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_md5.explain
@@ -1,2 +1,2 @@
-Project [md5(cast(g#0 as binary)) AS md5(g)#0]
+Project [md5(cast(g#0 as binary)) AS md5(CAST(g AS BINARY))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sha1.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sha1.explain
index 55077f061d720..5ae233d98369c 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sha1.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sha1.explain
@@ -1,2 +1,2 @@
-Project [sha1(cast(g#0 as binary)) AS sha1(g)#0]
+Project [sha1(cast(g#0 as binary)) AS sha1(CAST(g AS BINARY))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sha2.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sha2.explain
index 8ed2705cb17cb..f8a059e23ca9f 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sha2.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sha2.explain
@@ -1,2 +1,2 @@
-Project [sha2(cast(g#0 as binary), 512) AS sha2(g, 512)#0]
+Project [sha2(cast(g#0 as binary), 512) AS sha2(CAST(g AS BINARY), 512)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index eaeb1c775ddb6..93a01ea6c5740 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2178,20 +2178,23 @@ class SparkConnectPlanner(
   }

   private def transformCast(cast: proto.Expression.Cast): Expression = {
-    val dataType = cast.getCastToTypeCase match {
+    val rawDataType = cast.getCastToTypeCase match {
       case proto.Expression.Cast.CastToTypeCase.TYPE => transformDataType(cast.getType)
       case _ => parser.parseDataType(cast.getTypeStr)
     }
-    val mode = cast.getEvalMode match {
-      case proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY => Some(EvalMode.LEGACY)
-      case proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI => Some(EvalMode.ANSI)
-      case proto.Expression.Cast.EvalMode.EVAL_MODE_TRY => Some(EvalMode.TRY)
-      case _ => None
-    }
-    mode match {
-      case Some(m) => Cast(transformExpression(cast.getExpr), dataType, None, m)
-      case _ => Cast(transformExpression(cast.getExpr), dataType)
+    val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
+    val castExpr = cast.getEvalMode match {
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY =>
+        Cast(transformExpression(cast.getExpr), dataType, None, EvalMode.LEGACY)
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI =>
+        Cast(transformExpression(cast.getExpr), dataType, None, EvalMode.ANSI)
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_TRY =>
+        Cast(transformExpression(cast.getExpr), dataType, None, EvalMode.TRY)
+      case _ =>
+        Cast(transformExpression(cast.getExpr), dataType)
     }
+    castExpr.setTagValue(Cast.USER_SPECIFIED_CAST, ())
+    castExpr
   }

   private def transformUnresolvedRegex(regex: proto.Expression.UnresolvedRegex): Expression = {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 7e862bcfc533f..6721555220fe6 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -985,7 +985,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
           transform(connectTestRelation.observe("my_metric", "id".protoAttr.cast("string"))))
       },
       errorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
-      parameters = Map("expr" -> "\"id AS id\""))
+      parameters = Map("expr" -> "\"CAST(id AS STRING) AS id\""))

     val connectPlan2 =
       connectTestRelation.observe(
@@ -1016,7 +1016,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
             connectTestRelation.observe(Observation("my_metric"), "id".protoAttr.cast("string"))))
       },
       errorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
-      parameters = Map("expr" -> "\"id AS id\""))
+      parameters = Map("expr" -> "\"CAST(id AS STRING) AS id\""))
   }

   test("Test RandomSplit") {
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index fbfb4486446ff..c797087aef0ad 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -1046,6 +1046,21 @@ def test_lambda_str_representation(self):
             ),
         )

+    def test_cast_default_column_name(self):
+        cdf = self.connect.range(1).select(
+            CF.lit(b"123").cast("STRING"),
+            CF.lit(123).cast("STRING"),
+            CF.lit(123).cast("LONG"),
+            CF.lit(123).cast("DOUBLE"),
+        )
+        sdf = self.spark.range(1).select(
+            SF.lit(b"123").cast("STRING"),
+            SF.lit(123).cast("STRING"),
+            SF.lit(123).cast("LONG"),
+            SF.lit(123).cast("DOUBLE"),
+        )
+        self.assertEqual(cdf.columns, sdf.columns)
+

 if __name__ == "__main__":
     import unittest