From fe8b18b776f52835090fbbc0cc09d465b15f58ce Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Wed, 8 May 2024 18:06:35 +0800 Subject: [PATCH 01/65] [SPARK-48185][SQL] Fix 'symbolic reference class is not accessible: class sun.util.calendar.ZoneInfo' ### What changes were proposed in this pull request? I met the error below while debugging UTs because `sun.util.calendar.ZoneInfo` was loaded eagerly. This PR makes the relevant variables lazy. ```log Caused by: java.lang.IllegalAccessException: symbolic reference class is not accessible: class sun.util.calendar.ZoneInfo, from interface org.apache.spark.sql.catalyst.util.SparkDateTimeUtils (unnamed module 65d6b83b) at java.base/java.lang.invoke.MemberName.makeAccessException(MemberName.java:955) at java.base/java.lang.invoke.MethodHandles$Lookup.checkSymbolicClass(MethodHandles.java:3686) at java.base/java.lang.invoke.MethodHandles$Lookup.resolveOrFail(MethodHandles.java:3646) at java.base/java.lang.invoke.MethodHandles$Lookup.findVirtual(MethodHandles.java:2680) at org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.$init$(SparkDateTimeUtils.scala:206) at org.apache.spark.sql.catalyst.util.DateTimeUtils$.(DateTimeUtils.scala:41) ... 82 more ``` ### Why are the changes needed? sun.util.calendar.ZoneInfo is inaccessible in some scenarios. ### Does this PR introduce _any_ user-facing change? Yes, such errors might be delayed from backend-scheduling to job-scheduling. ### How was this patch tested? I tested with IDEA and UT debugging locally. ### Was this patch authored or co-authored using generative AI tooling? no Closes #46457 from yaooqinn/SPARK-48185. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala index 8db8c3cd39d74..0447d813e26a5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -198,7 +198,7 @@ trait SparkDateTimeUtils { } private val zoneInfoClassName = "sun.util.calendar.ZoneInfo" - private val getOffsetsByWallHandle = { + private lazy val getOffsetsByWallHandle = { val lookup = MethodHandles.lookup() val classType = SparkClassUtils.classForName(zoneInfoClassName) val methodName = "getOffsetsByWall" From fe3ef20d6418c4ed8965b2d61bf1d32b551e7b53 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 May 2024 19:14:46 +0900 Subject: [PATCH 02/65] [SPARK-48193][INFRA] Make `maven-deploy-plugin` retry 3 times ### What changes were proposed in this pull request? This PR aims to make the Maven plugin `maven-deploy-plugin` retry `3` times. ### Why are the changes needed? I found that our daily scheduled `publish snapshot` GitHub Actions workflow often fails: https://github.com/apache/spark/actions/workflows/publish_snapshot.yml I tried to make it as successful as possible by changing the number of retries from `1` (the default) to `3`. https://maven.apache.org/plugins/maven-deploy-plugin/deploy-mojo.html#retryFailedDeploymentCount https://maven.apache.org/plugins/maven-deploy-plugin/examples/deploy-network-issues.html#configuring-multiple-tries ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Keep observing the daily scheduled `publish snapshot` GitHub Actions workflow:
https://github.com/apache/spark/actions/workflows/publish_snapshot.yml ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46471 from panbingkun/SPARK-48193. Authored-by: panbingkun Signed-off-by: Hyukjin Kwon --- pom.xml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pom.xml b/pom.xml index f6f11d94cce32..c72482fd6a41f 100644 --- a/pom.xml +++ b/pom.xml @@ -3384,6 +3384,9 @@ org.apache.maven.plugins maven-deploy-plugin 3.1.2 + + 3 + org.apache.maven.plugins From 1b966d2eb329eed45b258d2134aacc0ea62d75dd Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 8 May 2024 19:18:06 +0900 Subject: [PATCH 03/65] [SPARK-47965][SQL][FOLLOW-UP] Uses `null` as its default value for `OptionalConfigEntry` ### What changes were proposed in this pull request? This PR partially reverts https://github.com/apache/spark/pull/46197 because of the behaviour change below: ```python >>> spark.conf.get("spark.sql.optimizer.excludedRules") '' ``` ### Why are the changes needed? To avoid behaviour change. ### Does this PR introduce _any_ user-facing change? No, the main change has not been released out yet. ### How was this patch tested? Manually as described above. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46472 from HyukjinKwon/SPARK-47965-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/internal/config/ConfigEntry.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index c07f2528ee709..a295ef06a6376 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -227,7 +227,7 @@ private[spark] class OptionalConfigEntry[T]( prependSeparator, alternatives, s => Some(rawValueConverter(s)), - v => v.map(rawStringConverter).getOrElse(ConfigEntry.UNDEFINED), + v => v.map(rawStringConverter).orNull, doc, isPublic, version From bd896cac168aa5793413058ca706c73705edbf96 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 8 May 2024 19:28:45 +0900 Subject: [PATCH 04/65] Revert "[SPARK-48163][CONNECT][TESTS] Disable `SparkConnectServiceSuite.SPARK-43923: commands send events - get_resources_command`" This reverts commit 56fe185c78a249cf88b1d7e5d1e67444e1b224db. 
--- .../spark/sql/connect/planner/SparkConnectServiceSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 59d9750c0fbf4..af18fca9dd216 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -418,14 +418,11 @@ class SparkConnectServiceSuite .setInput( proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), None), - // TODO(SPARK-48164) Reenable `commands send events - get_resources_command` - /* ( proto.Command .newBuilder() .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), None), - */ ( proto.Command .newBuilder() From d7f69e7003a3c7e7ad22a39e6aaacd183d26d326 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 8 May 2024 18:48:21 +0800 Subject: [PATCH 05/65] [SPARK-48190][PYTHON][PS][TESTS] Introduce a helper function to drop metadata ### What changes were proposed in this pull request? Introduce a helper function to drop metadata ### Why are the changes needed? existing helper function `remove_metadata` in PS doesn't support nested types, so cannot be reused in other places ### Does this PR introduce _any_ user-facing change? no, test only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #46466 from zhengruifeng/py_drop_meta. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/pandas/internal.py | 17 +++-------------- .../sql/tests/connect/test_connect_function.py | 11 +++++++++-- python/pyspark/sql/types.py | 13 +++++++++++++ 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 767ec9a57f9b5..8ab8d79d56868 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -33,6 +33,7 @@ Window, ) from pyspark.sql.types import ( # noqa: F401 + _drop_metadata, BooleanType, DataType, LongType, @@ -761,14 +762,8 @@ def __init__( # in a few tests when using Spark Connect. However, the function works properly. # Therefore, we temporarily perform Spark Connect tests by excluding metadata # until the issue is resolved. - def remove_metadata(struct_field: StructField) -> StructField: - new_struct_field = StructField( - struct_field.name, struct_field.dataType, struct_field.nullable - ) - return new_struct_field - assert all( - remove_metadata(index_field.struct_field) == remove_metadata(struct_field) + _drop_metadata(index_field.struct_field) == _drop_metadata(struct_field) for index_field, struct_field in zip(index_fields, struct_fields) ), (index_fields, struct_fields) else: @@ -795,14 +790,8 @@ def remove_metadata(struct_field: StructField) -> StructField: # in a few tests when using Spark Connect. However, the function works properly. # Therefore, we temporarily perform Spark Connect tests by excluding metadata # until the issue is resolved. 
- def remove_metadata(struct_field: StructField) -> StructField: - new_struct_field = StructField( - struct_field.name, struct_field.dataType, struct_field.nullable - ) - return new_struct_field - assert all( - remove_metadata(data_field.struct_field) == remove_metadata(struct_field) + _drop_metadata(data_field.struct_field) == _drop_metadata(struct_field) for data_field, struct_field in zip(data_fields, struct_fields) ), (data_fields, struct_fields) else: diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index 9d4db8cf7d15d..0f0abfd4b8567 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -21,7 +21,14 @@ from pyspark.util import is_remote_only from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql import SparkSession as PySparkSession -from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType +from pyspark.sql.types import ( + _drop_metadata, + StringType, + StructType, + StructField, + ArrayType, + IntegerType, +) from pyspark.testing import assertDataFrameEqual from pyspark.testing.pandasutils import PandasOnSparkTestUtils from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect @@ -1668,7 +1675,7 @@ def test_nested_lambda_function(self): ) # TODO: 'cdf.schema' has an extra metadata '{'__autoGeneratedAlias': 'true'}' - # self.assertEqual(cdf.schema, sdf.schema) + self.assertEqual(_drop_metadata(cdf.schema), _drop_metadata(sdf.schema)) self.assertEqual(cdf.collect(), sdf.collect()) def test_csv_functions(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 48aa3e8e4faba..41be12620fd56 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1569,6 +1569,19 @@ def toJson(self, zone_id: str = "UTC") -> str: _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?") +def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]: + assert isinstance(d, (DataType, StructField)) + if isinstance(d, StructField): + return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None) + elif isinstance(d, StructType): + return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields]) + elif isinstance(d, ArrayType): + return ArrayType(_drop_metadata(d.elementType), d.containsNull) + elif isinstance(d, MapType): + return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull) + return d + + def _parse_datatype_string(s: str) -> DataType: """ Parses the given data type string to a :class:`DataType`. The data type string format equals From 003823b39d3504a2a2cffaabbcab1dcf9429fa81 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Wed, 8 May 2024 20:09:22 +0800 Subject: [PATCH 06/65] [SPARK-48191][SQL] Support UTF-32 for string encode and decode ### What changes were proposed in this pull request? Enable support of UTF-32 ### Why are the changes needed? It already works, so we just need to enable it ### Does this PR introduce _any_ user-facing change? Yes, `decode(..., 'UTF-32')` and `encode(..., 'UTF-32')` will start working ### How was this patch tested? Manually checked in the spark shell ### Was this patch authored or co-authored using generative AI tooling? No Closes #46469 from vladimirg-db/vladimirg-db/support-utf-32-for-string-decode. 
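For context, the round trip that `encode()`/`decode()` perform once 'UTF-32' passes the allow-list check boils down to plain `java.nio` charset usage. A minimal sketch (illustrative only, not Spark's internal code path):

```scala
import java.nio.charset.Charset

// Round-trip a string through UTF-32; each code point becomes 4 bytes.
val utf32 = Charset.forName("UTF-32")
val bytes = "大千世界".getBytes(utf32)
assert(new String(bytes, utf32) == "大千世界")
```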
Authored-by: Vladimir Golubev Signed-off-by: Kent Yao --- docs/sql-migration-guide.md | 2 +- .../sql/catalyst/expressions/stringExpressions.scala | 10 +++++----- .../catalyst/expressions/StringExpressionsSuite.scala | 2 ++ .../analyzer-results/ansi/string-functions.sql.out | 7 +++++++ .../analyzer-results/string-functions.sql.out | 7 +++++++ .../resources/sql-tests/inputs/string-functions.sql | 1 + .../sql-tests/results/ansi/string-functions.sql.out | 8 ++++++++ .../sql-tests/results/string-functions.sql.out | 8 ++++++++ 8 files changed, 39 insertions(+), 6 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index fa49d6402b180..bd6604cb69c0f 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -32,7 +32,7 @@ license: | - Since Spark 4.0, `spark.sql.hive.metastore` drops the support of Hive prior to 2.0.0 as they require JDK 8 that Spark does not support anymore. Users should migrate to higher versions. - Since Spark 4.0, `spark.sql.parquet.compression.codec` drops the support of codec name `lz4raw`, please use `lz4_raw` instead. - Since Spark 4.0, when overflowing during casting timestamp to byte/short/int under non-ansi mode, Spark will return null instead a wrapping value. -- Since Spark 4.0, the `encode()` and `decode()` functions support only the following charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'. To restore the previous behavior when the function accepts charsets of the current JDK used by Spark, set `spark.sql.legacy.javaCharsets` to `true`. +- Since Spark 4.0, the `encode()` and `decode()` functions support only the following charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'. To restore the previous behavior when the function accepts charsets of the current JDK used by Spark, set `spark.sql.legacy.javaCharsets` to `true`. - Since Spark 4.0, the legacy datetime rebasing SQL configs with the prefix `spark.sql.legacy` are removed. To restore the previous behavior, use the following configs: - `spark.sql.parquet.int96RebaseModeInWrite` instead of `spark.sql.legacy.parquet.int96RebaseModeInWrite` - `spark.sql.parquet.datetimeRebaseModeInWrite` instead of `spark.sql.legacy.parquet.datetimeRebaseModeInWrite` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c2ea17de19533..0bdd7930b0bf9 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2646,7 +2646,7 @@ object Decode { arguments = """ Arguments: * bin - a binary expression to decode - * charset - one of the charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16' to decode `bin` into a STRING. It is case insensitive. + * charset - one of the charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32' to decode `bin` into a STRING. It is case insensitive. """, examples = """ Examples: @@ -2690,7 +2690,7 @@ case class Decode(params: Seq[Expression], replacement: Expression) arguments = """ Arguments: * bin - a binary expression to decode - * charset - one of the charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16' to decode `bin` into a STRING. It is case insensitive. 
+ * charset - one of the charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32' to decode `bin` into a STRING. It is case insensitive. """, since = "1.5.0", group = "string_funcs") @@ -2707,7 +2707,7 @@ case class StringDecode(bin: Expression, charset: Expression, legacyCharsets: Bo override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, StringTypeAnyCollation) private val supportedCharsets = Set( - "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16") + "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32") protected override def nullSafeEval(input1: Any, input2: Any): Any = { val fromCharset = input2.asInstanceOf[UTF8String].toString @@ -2762,7 +2762,7 @@ object StringDecode { arguments = """ Arguments: * str - a string expression - * charset - one of the charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16' to encode `str` into a BINARY. It is case insensitive. + * charset - one of the charsets 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32' to encode `str` into a BINARY. It is case insensitive. """, examples = """ Examples: @@ -2785,7 +2785,7 @@ case class Encode(str: Expression, charset: Expression, legacyCharsets: Boolean) Seq(StringTypeAnyCollation, StringTypeAnyCollation) private val supportedCharsets = Set( - "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16") + "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32") protected override def nullSafeEval(input1: Any, input2: Any): Any = { val toCharset = input2.asInstanceOf[UTF8String].toString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 0fcceef392389..51de44d8dfd98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -489,6 +489,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // non ascii characters are not allowed in the code, so we disable the scalastyle here. 
checkEvaluation( StringDecode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界") + checkEvaluation( + StringDecode(Encode(Literal("大千世界"), Literal("UTF-32")), Literal("UTF-32")), "大千世界") checkEvaluation( StringDecode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界")) checkEvaluation( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/string-functions.sql.out index 7ffd3cbd8bac6..c36dec0b105d7 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/string-functions.sql.out @@ -750,6 +750,13 @@ Project [decode(encode(abc, utf-8, false), utf-8) AS decode(encode(abc, utf-8), +- OneRowRelation +-- !query +select decode(encode('大千世界', 'utf-32'), 'utf-32') +-- !query analysis +Project [decode(encode(大千世界, utf-32, false), utf-32) AS decode(encode(大千世界, utf-32), utf-32)#x] ++- OneRowRelation + + -- !query select decode(1, 1, 'Southlake') -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out index 7ffd3cbd8bac6..c36dec0b105d7 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out @@ -750,6 +750,13 @@ Project [decode(encode(abc, utf-8, false), utf-8) AS decode(encode(abc, utf-8), +- OneRowRelation +-- !query +select decode(encode('大千世界', 'utf-32'), 'utf-32') +-- !query analysis +Project [decode(encode(大千世界, utf-32, false), utf-32) AS decode(encode(大千世界, utf-32), utf-32)#x] ++- OneRowRelation + + -- !query select decode(1, 1, 'Southlake') -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 64ea6e655d0b5..733720a7e21b2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -131,6 +131,7 @@ select encode(scol, ecol) from values('hello', 'Windows-xxx') as t(scol, ecol); select decode(); select decode(encode('abc', 'utf-8')); select decode(encode('abc', 'utf-8'), 'utf-8'); +select decode(encode('大千世界', 'utf-32'), 'utf-32'); select decode(1, 1, 'Southlake'); select decode(2, 1, 'Southlake'); select decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic'); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 8096cef266ec4..09d4f8892fa48 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -961,6 +961,14 @@ struct abc +-- !query +select decode(encode('大千世界', 'utf-32'), 'utf-32') +-- !query schema +struct +-- !query output +大千世界 + + -- !query select decode(1, 1, 'Southlake') -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 91ad830dd3d7a..506524840f107 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -893,6 +893,14 @@ struct abc +-- !query +select decode(encode('大千世界', 'utf-32'), 'utf-32') +-- !query schema +struct +-- !query output +大千世界 + + -- !query select decode(1, 1, 'Southlake') -- !query schema From 8950add773e63a910900f796950a6a58e40a8577 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 8 May 2024 20:11:24 +0800 Subject: [PATCH 07/65] [SPARK-48188][SQL] Consistently use normalized plan for cache ### What changes were proposed in this pull request? We must consistently use normalized plans for cache filling and lookup, or inconsistency will lead to cache misses. To guarantee this, this PR makes `CacheManager` the central place to do plan normalization, so that callers don't need to care about it. Now most APIs in `CacheManager` take either `Dataset` or `LogicalPlan`. For `Dataset`, we get the normalized plan directly. For `LogicalPlan`, we normalize it before further use. The caller side should pass `Dataset` when invoking `CacheManager`, if it already creates `Dataset`. This is to reduce the impact, as extra creation of `Dataset` may have perf issues or introduce unexpected analysis exception. ### Why are the changes needed? Avoid unnecessary cache misses for users who add custom normalization rules ### Does this PR introduce _any_ user-facing change? No, perf only ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46465 from cloud-fan/cache. Authored-by: Wenchen Fan Signed-off-by: Kent Yao --- .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../spark/sql/execution/CacheManager.scala | 160 +++++++++++------- .../spark/sql/execution/QueryExecution.scala | 37 ++-- .../command/AnalyzeColumnCommand.scala | 4 +- .../sql/execution/command/CommandUtils.scala | 2 +- .../datasources/v2/CacheTableExec.scala | 30 ++-- .../datasources/v2/DataSourceV2Strategy.scala | 2 +- .../spark/sql/internal/CatalogImpl.scala | 5 +- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../apache/spark/sql/test/SQLTestUtils.scala | 3 +- .../spark/sql/hive/CachedTableSuite.scala | 9 +- 11 files changed, 150 insertions(+), 107 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 18c9704afdf83..3e843e64ebbf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3904,8 +3904,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sparkSession.sharedState.cacheManager.uncacheQuery( - sparkSession, logicalPlan, cascade = false, blocking) + sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index ae99873a9f774..b96f257e6b5b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression} import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint -import 
org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan, ResolvedHint, SubqueryAlias, View} +import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan, ResolvedHint, View} import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -38,7 +38,10 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ -case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) { +case class CachedData( + // A normalized resolved plan (See QueryExecution#normalized). + plan: LogicalPlan, + cachedRepresentation: InMemoryRelation) { override def toString: String = s""" |CachedData( @@ -53,7 +56,9 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) * InMemoryRelation. This relation is automatically substituted query plans that return the * `sameResult` as the originally cached query. * - * Internal to Spark SQL. + * Internal to Spark SQL. All its public APIs take analyzed plans and will normalize them before + * further usage, or take [[Dataset]] and get its normalized plan. See `QueryExecution.normalize` + * for more details about plan normalization. */ class CacheManager extends Logging with AdaptiveSparkPlanHelper { @@ -77,41 +82,43 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { cachedData.isEmpty } + // Test-only + def cacheQuery(query: Dataset[_]): Unit = { + cacheQuery(query, tableName = None, storageLevel = MEMORY_AND_DISK) + } + /** * Caches the data produced by the logical representation of the given [[Dataset]]. - * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because - * recomputing the in-memory columnar representation of the underlying table is expensive. */ def cacheQuery( query: Dataset[_], - tableName: Option[String] = None, - storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = { - cacheQuery(query.sparkSession, query.queryExecution.normalized, tableName, storageLevel) + tableName: Option[String], + storageLevel: StorageLevel): Unit = { + cacheQueryInternal(query.sparkSession, query.queryExecution.normalized, tableName, storageLevel) } /** - * Caches the data produced by the given [[LogicalPlan]]. - * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because - * recomputing the in-memory columnar representation of the underlying table is expensive. + * Caches the data produced by the given [[LogicalPlan]]. The given plan will be normalized + * before being used further. */ def cacheQuery( spark: SparkSession, planToCache: LogicalPlan, - tableName: Option[String]): Unit = { - cacheQuery(spark, planToCache, tableName, MEMORY_AND_DISK) + tableName: Option[String], + storageLevel: StorageLevel): Unit = { + val normalized = QueryExecution.normalize(spark, planToCache) + cacheQueryInternal(spark, normalized, tableName, storageLevel) } - /** - * Caches the data produced by the given [[LogicalPlan]]. - */ - def cacheQuery( + // The `planToCache` should have been normalized. + private def cacheQueryInternal( spark: SparkSession, planToCache: LogicalPlan, tableName: Option[String], storageLevel: StorageLevel): Unit = { if (storageLevel == StorageLevel.NONE) { // Do nothing for StorageLevel.NONE since it will not actually cache any data. 
- } else if (lookupCachedData(planToCache).nonEmpty) { + } else if (lookupCachedDataInternal(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark) @@ -124,7 +131,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } this.synchronized { - if (lookupCachedData(planToCache).nonEmpty) { + if (lookupCachedDataInternal(planToCache).nonEmpty) { logWarning("Data has already been cached.") } else { val cd = CachedData(planToCache, inMemoryRelation) @@ -138,38 +145,64 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { /** * Un-cache the given plan or all the cache entries that refer to the given plan. - * @param query The [[Dataset]] to be un-cached. - * @param cascade If true, un-cache all the cache entries that refer to the given - * [[Dataset]]; otherwise un-cache the given [[Dataset]] only. + * + * @param query The [[Dataset]] to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * [[Dataset]]; otherwise un-cache the given [[Dataset]] only. + * @param blocking Whether to block until all blocks are deleted. */ + def uncacheQuery( + query: Dataset[_], + cascade: Boolean, + blocking: Boolean): Unit = { + uncacheQueryInternal(query.sparkSession, query.queryExecution.normalized, cascade, blocking) + } + + // An overload to provide default value for the `blocking` parameter. def uncacheQuery( query: Dataset[_], cascade: Boolean): Unit = { - uncacheQuery(query.sparkSession, query.queryExecution.normalized, cascade) + uncacheQuery(query, cascade, blocking = false) } /** * Un-cache the given plan or all the cache entries that refer to the given plan. - * @param spark The Spark session. - * @param plan The plan to be un-cached. - * @param cascade If true, un-cache all the cache entries that refer to the given - * plan; otherwise un-cache the given plan only. - * @param blocking Whether to block until all blocks are deleted. + * + * @param spark The Spark session. + * @param plan The plan to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * plan; otherwise un-cache the given plan only. + * @param blocking Whether to block until all blocks are deleted. */ def uncacheQuery( spark: SparkSession, plan: LogicalPlan, cascade: Boolean, - blocking: Boolean = false): Unit = { - uncacheQuery(spark, _.sameResult(plan), cascade, blocking) + blocking: Boolean): Unit = { + val normalized = QueryExecution.normalize(spark, plan) + uncacheQueryInternal(spark, normalized, cascade, blocking) + } + + // An overload to provide default value for the `blocking` parameter. + def uncacheQuery( + spark: SparkSession, + plan: LogicalPlan, + cascade: Boolean): Unit = { + uncacheQuery(spark, plan, cascade, blocking = false) + } + + // The `plan` should have been normalized. 
+ private def uncacheQueryInternal( + spark: SparkSession, + plan: LogicalPlan, + cascade: Boolean, + blocking: Boolean): Unit = { + uncacheByCondition(spark, _.sameResult(plan), cascade, blocking) } def uncacheTableOrView(spark: SparkSession, name: Seq[String], cascade: Boolean): Unit = { - uncacheQuery( - spark, - isMatchedTableOrView(_, name, spark.sessionState.conf), - cascade, - blocking = false) + uncacheByCondition( + spark, isMatchedTableOrView(_, name, spark.sessionState.conf), cascade, blocking = false) } private def isMatchedTableOrView(plan: LogicalPlan, name: Seq[String], conf: SQLConf): Boolean = { @@ -178,28 +211,24 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } plan match { - case SubqueryAlias(ident, LogicalRelation(_, _, Some(catalogTable), _)) => - val v1Ident = catalogTable.identifier - isSameName(ident.qualifier :+ ident.name) && isSameName(v1Ident.nameParts) + case LogicalRelation(_, _, Some(catalogTable), _) => + isSameName(catalogTable.identifier.nameParts) - case SubqueryAlias(ident, DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _)) => + case DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _) => import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper - isSameName(ident.qualifier :+ ident.name) && - isSameName(v2Ident.toQualifiedNameParts(catalog)) + isSameName(v2Ident.toQualifiedNameParts(catalog)) - case SubqueryAlias(ident, View(catalogTable, _, _)) => - val v1Ident = catalogTable.identifier - isSameName(ident.qualifier :+ ident.name) && isSameName(v1Ident.nameParts) + case View(catalogTable, _, _) => + isSameName(catalogTable.identifier.nameParts) - case SubqueryAlias(ident, HiveTableRelation(catalogTable, _, _, _, _)) => - val v1Ident = catalogTable.identifier - isSameName(ident.qualifier :+ ident.name) && isSameName(v1Ident.nameParts) + case HiveTableRelation(catalogTable, _, _, _, _) => + isSameName(catalogTable.identifier.nameParts) case _ => false } } - def uncacheQuery( + private def uncacheByCondition( spark: SparkSession, isMatchedPlan: LogicalPlan => Boolean, cascade: Boolean, @@ -252,10 +281,12 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } /** - * Tries to re-cache all the cache entries that refer to the given plan. + * Tries to re-cache all the cache entries that refer to the given plan. The given plan will be + * normalized before being used further. 
*/ def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = { - recacheByCondition(spark, _.plan.exists(_.sameResult(plan))) + val normalized = QueryExecution.normalize(spark, plan) + recacheByCondition(spark, _.plan.exists(_.sameResult(normalized))) } /** @@ -278,7 +309,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } val recomputedPlan = cd.copy(cachedRepresentation = newCache) this.synchronized { - if (lookupCachedData(recomputedPlan.plan).nonEmpty) { + if (lookupCachedDataInternal(recomputedPlan.plan).nonEmpty) { logWarning("While recaching, data was already added to cache.") } else { cachedData = recomputedPlan +: cachedData @@ -289,13 +320,23 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } } - /** Optionally returns cached data for the given [[Dataset]] */ + /** + * Optionally returns cached data for the given [[Dataset]] + */ def lookupCachedData(query: Dataset[_]): Option[CachedData] = { - lookupCachedData(query.queryExecution.normalized) + lookupCachedDataInternal(query.queryExecution.normalized) } - /** Optionally returns cached data for the given [[LogicalPlan]]. */ - def lookupCachedData(plan: LogicalPlan): Option[CachedData] = { + /** + * Optionally returns cached data for the given [[LogicalPlan]]. The given plan will be normalized + * before being used further. + */ + def lookupCachedData(session: SparkSession, plan: LogicalPlan): Option[CachedData] = { + val normalized = QueryExecution.normalize(session, plan) + lookupCachedDataInternal(normalized) + } + + private def lookupCachedDataInternal(plan: LogicalPlan): Option[CachedData] = { val result = cachedData.find(cd => plan.sameResult(cd.plan)) if (result.isDefined) { CacheManager.logCacheOperation(log"Dataframe cache hit for input plan:" + @@ -305,13 +346,16 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { result } - /** Replaces segments of the given logical plan with cached versions where possible. */ - def useCachedData(plan: LogicalPlan): LogicalPlan = { + /** + * Replaces segments of the given logical plan with cached versions where possible. The input + * plan must be normalized. + */ + private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { case command: IgnoreCachedData => command case currentFragment => - lookupCachedData(currentFragment).map { cached => + lookupCachedDataInternal(currentFragment).map { cached => // After cache lookup, we should still keep the hints from the input plan. val hints = EliminateResolvedHint.extractHintsFromPlan(currentFragment)._2 val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index d04d8dc2cd7fd..357484ca19df2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -134,19 +134,7 @@ class QueryExecution( // The plan that has been normalized by custom rules, so that it's more likely to hit cache. 
lazy val normalized: LogicalPlan = { - val normalizationRules = sparkSession.sessionState.planNormalizationRules - if (normalizationRules.isEmpty) { - commandExecuted - } else { - val planChangeLogger = new PlanChangeLogger[LogicalPlan]() - val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) => - val result = rule.apply(p) - planChangeLogger.logRule(rule.ruleName, p, result) - result - } - planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized) - normalized - } + QueryExecution.normalize(sparkSession, commandExecuted, Some(tracker)) } lazy val withCachedData: LogicalPlan = sparkSession.withActive { @@ -613,4 +601,27 @@ object QueryExecution { case e: Throwable => throw toInternalError(msg, e) } } + + def normalize( + session: SparkSession, + plan: LogicalPlan, + tracker: Option[QueryPlanningTracker] = None): LogicalPlan = { + val normalizationRules = session.sessionState.planNormalizationRules + if (normalizationRules.isEmpty) { + plan + } else { + val planChangeLogger = new PlanChangeLogger[LogicalPlan]() + val normalized = normalizationRules.foldLeft(plan) { (p, rule) => + val startTime = System.nanoTime() + val result = rule.apply(p) + val runTime = System.nanoTime() - startTime + val effective = !result.fastEquals(p) + tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective)) + planChangeLogger.logRule(rule.ruleName, p, result) + result + } + planChangeLogger.logBatch("Plan Normalization", plan, normalized) + normalized + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 299f41eb55e17..7b0ce3e59263f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -61,8 +61,8 @@ case class AnalyzeColumnCommand( private def analyzeColumnInCachedData(plan: LogicalPlan, sparkSession: SparkSession): Boolean = { val cacheManager = sparkSession.sharedState.cacheManager - val planToLookup = sparkSession.sessionState.executePlan(plan).analyzed - cacheManager.lookupCachedData(planToLookup).map { cachedData => + val df = Dataset.ofRows(sparkSession, plan) + cacheManager.lookupCachedData(df).map { cachedData => val columnsToAnalyze = getColumnsToAnalyze( tableIdent, cachedData.cachedRepresentation, columnNames, allColumns) cacheManager.analyzeColumnCacheQuery(sparkSession, cachedData, columnsToAnalyze) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index d7c5df151bf12..7acd1cb0852b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -240,7 +240,7 @@ object CommandUtils extends Logging { // Analyzes a catalog view if the view is cached val table = sparkSession.table(tableIdent.quotedString) val cacheManager = sparkSession.sharedState.cacheManager - if (cacheManager.lookupCachedData(table.logicalPlan).isDefined) { + if (cacheManager.lookupCachedData(table).isDefined) { if (!noScan) { // To collect table stats, materializes an underlying columnar RDD table.count() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala index fc8a40f885450..56c44a1256815 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala @@ -21,9 +21,9 @@ import java.util.Locale import org.apache.spark.internal.LogKeys.OPTIONS import org.apache.spark.internal.MDC -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.LocalTempView +import org.apache.spark.sql.catalyst.analysis.{LocalTempView, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -34,7 +34,6 @@ import org.apache.spark.storage.StorageLevel trait BaseCacheTableExec extends LeafV2CommandExec { def relationName: String def planToCache: LogicalPlan - def dataFrameForCachedPlan: DataFrame def isLazy: Boolean def options: Map[String, String] @@ -49,15 +48,12 @@ trait BaseCacheTableExec extends LeafV2CommandExec { logWarning(log"Invalid options: ${MDC(OPTIONS, withoutStorageLevel.mkString(", "))}") } - session.sharedState.cacheManager.cacheQuery( - session, - planToCache, - Some(relationName), - storageLevel) + val df = Dataset.ofRows(session, planToCache) + session.sharedState.cacheManager.cacheQuery(df, Some(relationName), storageLevel) if (!isLazy) { // Performs eager caching. - dataFrameForCachedPlan.count() + df.count() } Seq.empty @@ -74,10 +70,6 @@ case class CacheTableExec( override lazy val relationName: String = multipartIdentifier.quoted override lazy val planToCache: LogicalPlan = relation - - override lazy val dataFrameForCachedPlan: DataFrame = { - Dataset.ofRows(session, planToCache) - } } case class CacheTableAsSelectExec( @@ -89,7 +81,10 @@ case class CacheTableAsSelectExec( referredTempFunctions: Seq[String]) extends BaseCacheTableExec { override lazy val relationName: String = tempViewName - override lazy val planToCache: LogicalPlan = { + override def planToCache: LogicalPlan = UnresolvedRelation(Seq(tempViewName)) + + override def run(): Seq[InternalRow] = { + // CACHE TABLE AS TABLE creates a temp view and caches the temp view. CreateViewCommand( name = TableIdentifier(tempViewName), userSpecifiedColumns = Nil, @@ -103,12 +98,7 @@ case class CacheTableAsSelectExec( isAnalyzed = true, referredTempFunctions = referredTempFunctions ).run(session) - - dataFrameForCachedPlan.logicalPlan - } - - override lazy val dataFrameForCachedPlan: DataFrame = { - session.table(tempViewName) + super.run() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 0d926dcd99c4a..7a668b75c3c73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -83,7 +83,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat // given table, the cache's storage level is returned. 
private def invalidateTableCache(r: ResolvedTable)(): Option[StorageLevel] = { val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), Some(r.identifier)) - val cache = session.sharedState.cacheManager.lookupCachedData(v2Relation) + val cache = session.sharedState.cacheManager.lookupCachedData(session, v2Relation) session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true) if (cache.isDefined) { val cacheLevel = cache.get.cachedRepresentation.cacheBuilder.storageLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index df7c4ab1a0c7d..3e20a23a0a066 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -734,9 +734,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // same way as how a permanent view is handled. This also avoids a potential issue where a // dependent view becomes invalid because of the above while its data is still cached. val viewText = viewDef.desc.viewText - val plan = sparkSession.sessionState.executePlan(viewDef) - sparkSession.sharedState.cacheManager.uncacheQuery( - sparkSession, plan.analyzed, cascade = viewText.isDefined) + val df = Dataset.ofRows(sparkSession, viewDef) + sparkSession.sharedState.cacheManager.uncacheQuery(df, cascade = viewText.isDefined) } catch { case NonFatal(_) => // ignore } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0ad9ceefc4196..d023fb82185a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -1107,7 +1107,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils assert(queryStats1.map(_._1.name).isEmpty) val cacheManager = spark.sharedState.cacheManager - val cachedData = cacheManager.lookupCachedData(query().logicalPlan) + val cachedData = cacheManager.lookupCachedData(query()) assert(cachedData.isDefined) val queryAttrs = cachedData.get.plan.output assert(queryAttrs.size === 3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 7da2bb47038ed..5fbf379644f6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -339,8 +339,7 @@ private[sql] trait SQLTestUtilsBase val tableIdent = spark.sessionState.sqlParser.parseTableIdentifier(tableName) val cascade = !spark.sessionState.catalog.isTempView(tableIdent) spark.sharedState.cacheManager.uncacheQuery( - spark, - spark.table(tableName).logicalPlan, + spark.table(tableName), cascade = cascade, blocking = true) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 89fe10d5c4bd9..d7918f8cbf4f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -335,9 +335,10 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto options = Map.empty)(sparkSession = spark) val plan = LogicalRelation(relation, tableMeta) - 
spark.sharedState.cacheManager.cacheQuery(Dataset.ofRows(spark, plan)) + val df = Dataset.ofRows(spark, plan) + spark.sharedState.cacheManager.cacheQuery(df) - assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined) + assert(spark.sharedState.cacheManager.lookupCachedData(df).isDefined) val sameCatalog = new CatalogFileIndex(spark, tableMeta, 0) val sameRelation = HadoopFsRelation( @@ -347,9 +348,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto bucketSpec = None, fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val samePlan = LogicalRelation(sameRelation, tableMeta) + val samePlanDf = Dataset.ofRows(spark, LogicalRelation(sameRelation, tableMeta)) - assert(spark.sharedState.cacheManager.lookupCachedData(samePlan).isDefined) + assert(spark.sharedState.cacheManager.lookupCachedData(samePlanDf).isDefined) } } From 8d7081639ab47996e357a0a968ca74661795da85 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 8 May 2024 20:57:08 +0800 Subject: [PATCH 08/65] [SPARK-48161][SQL] Add collation support for JSON expressions ### What changes were proposed in this pull request? Introduce collation awareness for JSON expressions: get_json_object, json_tuple, from_json, to_json, json_array_length, json_object_keys. ### Why are the changes needed? Add collation support for JSON expressions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for JSON functions: get_json_object, json_tuple, from_json, to_json, json_array_length, json_object_keys. ### How was this patch tested? E2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46462 from uros-db/json-expressions. Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../expressions/complexTypeCreator.scala | 5 +- .../expressions/jsonExpressions.scala | 20 +- .../sql/catalyst/json/JacksonParser.scala | 2 +- .../sql/CollationSQLExpressionsSuite.scala | 198 ++++++++++++++++++ 4 files changed, 213 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index c38b6cea9a0a5..4c0d005340606 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -374,7 +374,8 @@ object CreateStruct { // We should always use the last part of the column name (`c` in the above example) as the // alias name inside CreateNamedStruct. 
case (u: UnresolvedAttribute, _) => Seq(Literal(u.nameParts.last), u) - case (u @ UnresolvedExtractValue(_, e: Literal), _) if e.dataType == StringType => Seq(e, u) + case (u @ UnresolvedExtractValue(_, e: Literal), _) if e.dataType.isInstanceOf[StringType] => + Seq(e, u) case (a: Alias, _) => Seq(Literal(a.name), a) case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) case (e: NamedExpression, _) => Seq(NamePlaceholder, e) @@ -465,7 +466,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with toSQLId(prettyName), Seq("2n (n > 0)"), children.length ) } else { - val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType.isInstanceOf[StringType]) if (invalidNames.nonEmpty) { DataTypeMismatch( errorSubClass = "CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 35e30ceb45cb5..8258bb389e2da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePatt import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.Utils @@ -132,8 +133,9 @@ case class GetJsonObject(json: Expression, path: Expression) override def left: Expression = json override def right: Expression = path - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def prettyName: String = "get_json_object" @@ -477,7 +479,7 @@ case class JsonTuple(children: Seq[Expression]) @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null) override def elementSchema: StructType = StructType(fieldExpressions.zipWithIndex.map { - case (_, idx) => StructField(s"c$idx", StringType, nullable = true) + case (_, idx) => StructField(s"c$idx", children.head.dataType, nullable = true) }) override def prettyName: String = "json_tuple" @@ -487,7 +489,7 @@ case class JsonTuple(children: Seq[Expression]) throw QueryCompilationErrors.wrongNumArgsError( toSQLId(prettyName), Seq("> 1"), children.length ) - } else if (children.forall(child => StringType.acceptsType(child.dataType))) { + } else if (children.forall(child => StringTypeAnyCollation.acceptsType(child.dataType))) { TypeCheckResult.TypeCheckSuccess } else { DataTypeMismatch( @@ -722,7 +724,7 @@ case class JsonToStructs( converter(parser.parse(json.asInstanceOf[UTF8String])) } - override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil override def sql: String = schema match { case _: MapType => "entries" @@ -824,7 +826,7 @@ case class 
StructsToJson( } } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def checkInputDataTypes(): TypeCheckResult = inputSchema match { case dt @ (_: StructType | _: MapType | _: ArrayType | _: VariantType) => @@ -957,7 +959,7 @@ case class SchemaOfJson( case class LengthOfJsonArray(child: Expression) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { - override def inputTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override def dataType: DataType = IntegerType override def nullable: Boolean = true override def prettyName: String = "json_array_length" @@ -1030,8 +1032,8 @@ case class LengthOfJsonArray(child: Expression) extends UnaryExpression case class JsonObjectKeys(child: Expression) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { - override def inputTypes: Seq[DataType] = Seq(StringType) - override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType) override def nullable: Boolean = true override def prettyName: String = "json_object_keys" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 3c42f72fa6b6c..848c20ee36bef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -275,7 +275,7 @@ class JacksonParser( } } - case StringType => + case _: StringType => (parser: JsonParser) => parseJsonToken[UTF8String](parser, dataType) { case VALUE_STRING => UTF8String.fromString(parser.getText) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 4314ff97a3cf3..19f34ec15aa07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -460,6 +460,204 @@ class CollationSQLExpressionsSuite }) } + test("Support GetJsonObject json expression with collation") { + case class GetJsonObjectTestCase( + input: String, + path: String, + collationName: String, + result: String + ) + + val testCases = Seq( + GetJsonObjectTestCase("{\"a\":\"b\"}", "$.a", "UTF8_BINARY", "b"), + GetJsonObjectTestCase("{\"A\":\"1\"}", "$.A", "UTF8_BINARY_LCASE", "1"), + GetJsonObjectTestCase("{\"x\":true}", "$.x", "UNICODE", "true"), + GetJsonObjectTestCase("{\"X\":1}", "$.X", "UNICODE_CI", "1") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT get_json_object('${t.input}', '${t.path}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support JsonTuple json expression with collation") { + case class JsonTupleTestCase( + input: String, + names: String, + collationName: String, + result: Row + ) + + val testCases = Seq( + JsonTupleTestCase("{\"a\":1, \"b\":2}", "'a', 'b'", "UTF8_BINARY", + Row("1", 
"2")), + JsonTupleTestCase("{\"A\":\"3\", \"B\":\"4\"}", "'A', 'B'", "UTF8_BINARY_LCASE", + Row("3", "4")), + JsonTupleTestCase("{\"x\":true, \"y\":false}", "'x', 'y'", "UNICODE", + Row("true", "false")), + JsonTupleTestCase("{\"X\":null, \"Y\":null}", "'X', 'Y'", "UNICODE_CI", + Row(null, null)) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT json_tuple('${t.input}', ${t.names}) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, t.result) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support JsonToStructs json expression with collation") { + case class JsonToStructsTestCase( + input: String, + schema: String, + collationName: String, + result: Row + ) + + val testCases = Seq( + JsonToStructsTestCase("{\"a\":1, \"b\":2.0}", "a INT, b DOUBLE", + "UTF8_BINARY", Row(Row(1, 2.0))), + JsonToStructsTestCase("{\"A\":\"3\", \"B\":4}", "A STRING COLLATE UTF8_BINARY_LCASE, B INT", + "UTF8_BINARY_LCASE", Row(Row("3", 4))), + JsonToStructsTestCase("{\"x\":true, \"y\":null}", "x BOOLEAN, y VOID", + "UNICODE", Row(Row(true, null))), + JsonToStructsTestCase("{\"X\":null, \"Y\":false}", "X VOID, Y BOOLEAN", + "UNICODE_CI", Row(Row(null, false))) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT from_json('${t.input}', '${t.schema}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, t.result) + val dataType = StructType.fromDDL(t.schema) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support StructsToJson json expression with collation") { + case class StructsToJsonTestCase( + struct: String, + collationName: String, + result: Row + ) + + val testCases = Seq( + StructsToJsonTestCase("named_struct('a', 1, 'b', 2)", + "UTF8_BINARY", Row("{\"a\":1,\"b\":2}")), + StructsToJsonTestCase("array(named_struct('a', 1, 'b', 2))", + "UTF8_BINARY_LCASE", Row("[{\"a\":1,\"b\":2}]")), + StructsToJsonTestCase("map('a', named_struct('b', 1))", + "UNICODE", Row("{\"a\":{\"b\":1}}")), + StructsToJsonTestCase("array(map('a', 1))", + "UNICODE_CI", Row("[{\"a\":1}]")) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT to_json(${t.struct}) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, t.result) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support LengthOfJsonArray json expression with collation") { + case class LengthOfJsonArrayTestCase( + input: String, + collationName: String, + result: Row + ) + + val testCases = Seq( + LengthOfJsonArrayTestCase("'[1,2,3,4]'", "UTF8_BINARY", Row(4)), + LengthOfJsonArrayTestCase("'[1,2,3,{\"f1\":1,\"f2\":[5,6]},4]'", "UTF8_BINARY_LCASE", Row(5)), + LengthOfJsonArrayTestCase("'[1,2'", "UNICODE", Row(null)), + LengthOfJsonArrayTestCase("'['", "UNICODE_CI", Row(null)) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT json_array_length(${t.input}) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + 
checkAnswer(testQuery, t.result) + assert(testQuery.schema.fields.head.dataType.sameType(IntegerType)) + } + }) + } + + test("Support JsonObjectKeys json expression with collation") { + case class JsonObjectKeysJsonArrayTestCase( + input: String, + collationName: String, + result: Row + ) + + val testCases = Seq( + JsonObjectKeysJsonArrayTestCase("{}", "UTF8_BINARY", + Row(Seq())), + JsonObjectKeysJsonArrayTestCase("{\"k\":", "UTF8_BINARY_LCASE", + Row(null)), + JsonObjectKeysJsonArrayTestCase("{\"k1\": \"v1\"}", "UNICODE", + Row(Seq("k1"))), + JsonObjectKeysJsonArrayTestCase("{\"k1\":1,\"k2\":{\"k3\":3, \"k4\":4}}", "UNICODE_CI", + Row(Seq("k1", "k2"))) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT json_object_keys('${t.input}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, t.result) + val dataType = ArrayType(StringType(t.collationName)) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + test("Support StringToMap expression with collation") { // Supported collations case class StringToMapTestCase[R](t: String, p: String, k: String, c: String, result: R) From 47afe77242abf639a1d6966ce60cfd170a9d7d20 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 8 May 2024 07:44:22 -0700 Subject: [PATCH 09/65] [SPARK-48184][PYTHON][CONNECT] Always set the seed of `Dataframe.sample` in Client side ### What changes were proposed in this pull request? Always set the seed of `Dataframe.sample` in Client side ### Why are the changes needed? Bug fix If the seed is not set in Client, it will be set in server side with a random int https://github.com/apache/spark/blob/c4df12cc884cddefcfcf8324b4d7b9349fb4f6a0/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L386 which cause inconsistent results in multiple executions In Spark Classic: ``` In [1]: df = spark.range(10000).sample(0.1) In [2]: [df.count() for i in range(10)] Out[2]: [1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006] ``` In Spark Connect: before: ``` In [1]: df = spark.range(10000).sample(0.1) In [2]: [df.count() for i in range(10)] Out[2]: [969, 1005, 958, 996, 987, 1026, 991, 1020, 1012, 979] ``` after: ``` In [1]: df = spark.range(10000).sample(0.1) In [2]: [df.count() for i in range(10)] Out[2]: [1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032] ``` ### Does this PR introduce _any_ user-facing change? yes, bug fix ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #46456 from zhengruifeng/py_connect_sample_seed. 
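To make the failure mode concrete, the following is a small, self-contained sketch in plain Python (not Spark or PySpark code) of why drawing a fresh seed for every execution of the same plan yields a different count on every action, while pinning the seed once at plan-construction time, as this patch now does on the client, keeps repeated actions consistent:

```python
import random

NUM_ROWS, FRACTION = 10_000, 0.1

def sampled_count(seed: int) -> int:
    # Stand-in for one execution of `df.sample(FRACTION)` followed by `count()`.
    rng = random.Random(seed)
    return sum(rng.random() < FRACTION for _ in range(NUM_ROWS))

# Previous behaviour: a fresh seed is drawn per execution, so counts drift.
print([sampled_count(random.randint(0, 2**63 - 1)) for _ in range(3)])

# Behaviour after this patch: the client pins one seed when the plan is built.
pinned_seed = random.randint(0, 2**63 - 1)
print([sampled_count(pinned_seed) for _ in range(3)])  # same count three times
```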
Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/connect/dataframe.py | 2 +- python/pyspark/sql/tests/connect/test_connect_plan.py | 2 +- python/pyspark/sql/tests/test_dataframe.py | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index f9a209d2bcb3d..843c92a9b27d2 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -813,7 +813,7 @@ def sample( if withReplacement is None: withReplacement = False - seed = int(seed) if seed is not None else None + seed = int(seed) if seed is not None else random.randint(0, sys.maxsize) return DataFrame( plan.Sample( diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py index 09c3171ee11fd..e8d04aeada740 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan.py @@ -443,7 +443,7 @@ def test_sample(self): self.assertEqual(plan.root.sample.lower_bound, 0.0) self.assertEqual(plan.root.sample.upper_bound, 0.3) self.assertEqual(plan.root.sample.with_replacement, False) - self.assertEqual(plan.root.sample.HasField("seed"), False) + self.assertEqual(plan.root.sample.HasField("seed"), True) self.assertEqual(plan.root.sample.deterministic_order, False) plan = ( diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 16dd0d2a3bf7c..f491b496ddae5 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -430,6 +430,11 @@ def test_sample(self): IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count() ) + def test_sample_with_random_seed(self): + df = self.spark.range(10000).sample(0.1) + cnts = [df.count() for i in range(10)] + self.assertEqual(1, len(set(cnts))) + def test_toDF_with_string(self): df = self.spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)]) data = [("John", 30), ("Alice", 25), ("Bob", 28)] From e0c406eaef36d95a106b6ce14086654ace6202af Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 May 2024 08:50:02 -0700 Subject: [PATCH 10/65] [SPARK-48198][BUILD] Upgrade jackson to 2.17.1 ### What changes were proposed in this pull request? The pr aims to upgrade `jackson` from `2.17.0` to `2.17.1`. ### Why are the changes needed? The full release notes: https://github.com/FasterXML/jackson/wiki/Jackson-Release-2.17.1 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46476 from panbingkun/SPARK-48198. 
Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 14 +++++++------- pom.xml | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 5d933e34e40ba..73d41e9eeb337 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -104,15 +104,15 @@ icu4j/72.1//icu4j-72.1.jar ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.2//ivy-2.5.2.jar -jackson-annotations/2.17.0//jackson-annotations-2.17.0.jar +jackson-annotations/2.17.1//jackson-annotations-2.17.1.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar -jackson-core/2.17.0//jackson-core-2.17.0.jar -jackson-databind/2.17.0//jackson-databind-2.17.0.jar -jackson-dataformat-cbor/2.17.0//jackson-dataformat-cbor-2.17.0.jar -jackson-dataformat-yaml/2.17.0//jackson-dataformat-yaml-2.17.0.jar -jackson-datatype-jsr310/2.17.0//jackson-datatype-jsr310-2.17.0.jar +jackson-core/2.17.1//jackson-core-2.17.1.jar +jackson-databind/2.17.1//jackson-databind-2.17.1.jar +jackson-dataformat-cbor/2.17.1//jackson-dataformat-cbor-2.17.1.jar +jackson-dataformat-yaml/2.17.1//jackson-dataformat-yaml-2.17.1.jar +jackson-datatype-jsr310/2.17.1//jackson-datatype-jsr310-2.17.1.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.13/2.17.0//jackson-module-scala_2.13-2.17.0.jar +jackson-module-scala_2.13/2.17.1//jackson-module-scala_2.13-2.17.1.jar jakarta.annotation-api/2.0.0//jakarta.annotation-api-2.0.0.jar jakarta.inject-api/2.0.1//jakarta.inject-api-2.0.1.jar jakarta.servlet-api/5.0.0//jakarta.servlet-api-5.0.0.jar diff --git a/pom.xml b/pom.xml index c72482fd6a41f..c3ff5d101c224 100644 --- a/pom.xml +++ b/pom.xml @@ -183,8 +183,8 @@ true true 1.9.13 - 2.17.0 - 2.17.0 + 2.17.1 + 2.17.1 2.3.1 3.0.2 1.1.10.5 From 9d79ab42b127d1a12164cec260bfbd69f6da8b74 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 8 May 2024 09:40:03 -0700 Subject: [PATCH 11/65] [SPARK-48200][INFRA] Split `build_python.yml` into per-version cron jobs ### What changes were proposed in this pull request? This PR aims to split `build_python.yml` into per-version cron jobs. Technically, this includes a revert of SPARK-48149 and choose [the discussed alternative](https://github.com/apache/spark/pull/46407#discussion_r1591586209). - https://github.com/apache/spark/pull/46407 - https://github.com/apache/spark/pull/46454 ### Why are the changes needed? To recover Python CI successfully in ASF INFRA policy. - https://github.com/apache/spark/actions/workflows/build_python.yml ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual review. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46477 from dongjoon-hyun/SPARK-48200. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- ...build_python.yml => build_python_3.10.yml} | 16 +------ .github/workflows/build_python_3.12.yml | 45 +++++++++++++++++++ .github/workflows/build_python_pypy3.9.yml | 45 +++++++++++++++++++ 3 files changed, 92 insertions(+), 14 deletions(-) rename .github/workflows/{build_python.yml => build_python_3.10.yml} (63%) create mode 100644 .github/workflows/build_python_3.12.yml create mode 100644 .github/workflows/build_python_pypy3.9.yml diff --git a/.github/workflows/build_python.yml b/.github/workflows/build_python_3.10.yml similarity index 63% rename from .github/workflows/build_python.yml rename to .github/workflows/build_python_3.10.yml index efa281d6a279c..5ae37fbc9120e 100644 --- a/.github/workflows/build_python.yml +++ b/.github/workflows/build_python_3.10.yml @@ -17,26 +17,14 @@ # under the License. # -# According to https://infra.apache.org/github-actions-policy.html, -# all workflows SHOULD have a job concurrency level less than or equal to 15. -# To do that, we run one python version per cron schedule -name: "Build / Python-only (master, PyPy 3.9/Python 3.10/Python 3.12)" +name: "Build / Python-only (master, Python 3.10)" on: schedule: - - cron: '0 15 * * *' - cron: '0 17 * * *' - - cron: '0 19 * * *' jobs: run-build: - strategy: - fail-fast: false - matrix: - include: - - pyversion: ${{ github.event.schedule == '0 15 * * *' && 'pypy3' }} - - pyversion: ${{ github.event.schedule == '0 17 * * *' && 'python3.10' }} - - pyversion: ${{ github.event.schedule == '0 19 * * *' && 'python3.12' }} permissions: packages: write name: Run @@ -48,7 +36,7 @@ jobs: hadoop: hadoop3 envs: >- { - "PYTHON_TO_TEST": "${{ matrix.pyversion }}" + "PYTHON_TO_TEST": "python3.10" } jobs: >- { diff --git a/.github/workflows/build_python_3.12.yml b/.github/workflows/build_python_3.12.yml new file mode 100644 index 0000000000000..e1fd45a7d8838 --- /dev/null +++ b/.github/workflows/build_python_3.12.yml @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +name: "Build / Python-only (master, Python 3.12)" + +on: + schedule: + - cron: '0 19 * * *' + +jobs: + run-build: + permissions: + packages: write + name: Run + uses: ./.github/workflows/build_and_test.yml + if: github.repository == 'apache/spark' + with: + java: 17 + branch: master + hadoop: hadoop3 + envs: >- + { + "PYTHON_TO_TEST": "python3.12" + } + jobs: >- + { + "pyspark": "true", + "pyspark-pandas": "true" + } diff --git a/.github/workflows/build_python_pypy3.9.yml b/.github/workflows/build_python_pypy3.9.yml new file mode 100644 index 0000000000000..e05071ef034a0 --- /dev/null +++ b/.github/workflows/build_python_pypy3.9.yml @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: "Build / Python-only (master, PyPy 3.9)" + +on: + schedule: + - cron: '0 15 * * *' + +jobs: + run-build: + permissions: + packages: write + name: Run + uses: ./.github/workflows/build_and_test.yml + if: github.repository == 'apache/spark' + with: + java: 17 + branch: master + hadoop: hadoop3 + envs: >- + { + "PYTHON_TO_TEST": "pypy3" + } + jobs: >- + { + "pyspark": "true", + "pyspark-pandas": "true" + } From 70e5d2aa7a992a6f4ff9c7d8e3752ce1d3d488f2 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 8 May 2024 10:47:52 -0700 Subject: [PATCH 12/65] [SPARK-48202][INFRA] Spin off `pyspark` tests from `build_branch35.yml` Daily CI ### What changes were proposed in this pull request? This PR aims to create `build_branch35_python.yml` in order to spin off `pyspark` tests from `build_branch35.yml` Daily CI. ### Why are the changes needed? Currently, `build_branch35.yml` creates more than 15 test pipelines concurrently which is beyond of ASF Infra policy. - https://github.com/apache/spark/actions/workflows/build_branch35.yml We had better offload this to `Python only Daily CI` like `master` branch's `Python Only` Daily CI. - https://github.com/apache/spark/actions/workflows/build_python_3.10.yml ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual review. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46479 from dongjoon-hyun/SPARK-48202. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .github/workflows/build_branch35.yml | 1 - .github/workflows/build_branch35_python.yml | 45 +++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/build_branch35_python.yml diff --git a/.github/workflows/build_branch35.yml b/.github/workflows/build_branch35.yml index 55616c2f1f017..2ec080d5722c1 100644 --- a/.github/workflows/build_branch35.yml +++ b/.github/workflows/build_branch35.yml @@ -43,7 +43,6 @@ jobs: jobs: >- { "build": "true", - "pyspark": "true", "sparkr": "true", "tpcds-1g": "true", "docker-integration-tests": "true", diff --git a/.github/workflows/build_branch35_python.yml b/.github/workflows/build_branch35_python.yml new file mode 100644 index 0000000000000..1585534d33ba9 --- /dev/null +++ b/.github/workflows/build_branch35_python.yml @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: "Build / Python-only (branch-3.5)" + +on: + schedule: + - cron: '0 11 * * *' + +jobs: + run-build: + permissions: + packages: write + name: Run + uses: ./.github/workflows/build_and_test.yml + if: github.repository == 'apache/spark' + with: + java: 8 + branch: branch-3.5 + hadoop: hadoop3 + envs: >- + { + "PYTHON_TO_TEST": "" + } + jobs: >- + { + "pyspark": "true", + "pyspark-pandas": "true" + } From fbfcd402851ee604789b8ba72a1ee0e67ef5ebe4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 8 May 2024 12:30:12 -0700 Subject: [PATCH 13/65] [SPARK-48203][INFRA] Spin off `pyspark` tests from `build_branch34.yml` Daily CI ### What changes were proposed in this pull request? This PR aims to create `build_branch34_python.yml` in order to spin off `pyspark` tests from `build_branch34.yml` Daily CI. ### Why are the changes needed? Currently, `build_branch34.yml` creates more than 15 test pipelines concurrently which is beyond of ASF Infra policy. - https://github.com/apache/spark/actions/workflows/build_branch35.yml We had better offload this to `Python only Daily CI` like `master` branch's `Python Only` Daily CI. - https://github.com/apache/spark/actions/workflows/build_python_3.10.yml ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual review. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46480 from dongjoon-hyun/SPARK-48203. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .github/workflows/build_branch34.yml | 1 - .github/workflows/build_branch34_python.yml | 45 +++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/build_branch34_python.yml diff --git a/.github/workflows/build_branch34.yml b/.github/workflows/build_branch34.yml index 68887970d4d81..deb6c42407970 100644 --- a/.github/workflows/build_branch34.yml +++ b/.github/workflows/build_branch34.yml @@ -43,7 +43,6 @@ jobs: jobs: >- { "build": "true", - "pyspark": "true", "sparkr": "true", "tpcds-1g": "true", "docker-integration-tests": "true", diff --git a/.github/workflows/build_branch34_python.yml b/.github/workflows/build_branch34_python.yml new file mode 100644 index 0000000000000..c109ba2dc7922 --- /dev/null +++ b/.github/workflows/build_branch34_python.yml @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: "Build / Python-only (branch-3.4)" + +on: + schedule: + - cron: '0 9 * * *' + +jobs: + run-build: + permissions: + packages: write + name: Run + uses: ./.github/workflows/build_and_test.yml + if: github.repository == 'apache/spark' + with: + java: 8 + branch: branch-3.4 + hadoop: hadoop3 + envs: >- + { + "PYTHON_TO_TEST": "" + } + jobs: >- + { + "pyspark": "true", + "pyspark-pandas": "true" + } From 21548a8cc5c527d4416a276a852f967b4410bd4b Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Wed, 8 May 2024 15:44:02 -0400 Subject: [PATCH 14/65] [SPARK-47545][CONNECT] Dataset `observe` support for the Scala client ### What changes were proposed in this pull request? This PR adds support for `Dataset.observe` to the Spark Connect Scala client. Note that the support here does not include listener support as it runs on the serve side. This PR includes a small refactoring to the `Observation` helper class. We extracted methods that are not bound to the SparkSession to `spark-api`, and added two subclasses on both `spark-core` and `spark-jvm-client`. ### Why are the changes needed? Before this PR, the `DF.observe` method is only supported in the Python client. ### Does this PR introduce _any_ user-facing change? Yes. The user can now issue `DF.observe(name, metrics...)` or `DF.observe(observationObject, metrics...)` to get stats of columns of a dataframe. ### How was this patch tested? Added new e2e tests. ### Was this patch authored or co-authored using generative AI tooling? Nope. Closes #45701 from xupefei/scala-observe. 
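For reviewers who want to see the new surface in one place, here is a minimal usage sketch of the two `observe` overloads added below. It assumes a reachable Spark Connect endpoint at `sc://localhost`; the dataset, metric names, and column expressions are made up for illustration and are not part of this change.

```scala
// Sketch only: exercises the API added in this patch against an assumed local server.
import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
val df = spark.range(100)

// Named metrics: collected together with the result and read back afterwards.
val named = df.observe("stats", count(lit(1)).as("rows"), max(col("id")).as("max_id"))
val namedMetrics = named.collectResult().getObservedMetrics // Map[String, Row]

// Observation-based metrics: `get` blocks until the observed query has finished.
val ob = Observation("stats2")
val observed = df.observe(ob, min(col("id")), avg(col("id")))
observed.collect()
val obMetrics = ob.get // Map[String, Any]
```

As the e2e test added in this patch shows, calling `ob.get` from another thread simply blocks until the first action on the observed Dataset completes.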
Authored-by: Paddy Xu Signed-off-by: Herman van Hovell --- .../scala/org/apache/spark/sql/Dataset.scala | 63 +++++- .../org/apache/spark/sql/Observation.scala | 46 +++++ .../org/apache/spark/sql/SparkSession.scala | 31 ++- .../apache/spark/sql/ClientE2ETestSuite.scala | 43 ++++ .../CheckConnectJvmClientCompatibility.scala | 3 - .../main/protobuf/spark/connect/base.proto | 1 + .../sql/connect/client/SparkResult.scala | 44 +++- .../common/LiteralValueProtoConverter.scala | 2 +- .../execution/ExecuteThreadRunner.scala | 1 + .../execution/SparkConnectPlanExecution.scala | 12 +- python/pyspark/sql/connect/proto/base_pb2.py | 188 +++++++++--------- python/pyspark/sql/connect/proto/base_pb2.pyi | 5 +- .../apache/spark/sql/ObservationBase.scala | 113 +++++++++++ .../org/apache/spark/sql/Observation.scala | 62 +----- 14 files changed, 448 insertions(+), 166 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 9a42afebf8f2b..37f770319b695 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3337,8 +3337,69 @@ class Dataset[T] private[sql] ( } } + /** + * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset + * that returns the same result as the input, with the following guarantees:
  <ul> <li>It will + * compute the defined aggregates (metrics) on all the data that is flowing through the Dataset + * at that point.</li>
  <li>It will report the value of the defined aggregate columns as soon as + * we reach a completion point. A completion point is currently defined as the end of a + * query.</li> </ul>
Please note that continuous execution is currently not supported. + * + * The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or + * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that + * contain references to the input Dataset's columns must always be wrapped in an aggregate + * function. + * + * A user can retrieve the metrics by calling + * `org.apache.spark.sql.Dataset.collectResult().getObservedMetrics`. + * + * {{{ + * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it + * val observed_ds = ds.observe("my_metrics", count(lit(1)).as("rows"), max($"id").as("maxid")) + * observed_ds.write.parquet("ds.parquet") + * val metrics = observed_ds.collectResult().getObservedMetrics + * }}} + * + * @group typedrel + * @since 4.0.0 + */ + @scala.annotation.varargs def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = { - throw new UnsupportedOperationException("observe is not implemented.") + sparkSession.newDataset(agnosticEncoder) { builder => + builder.getCollectMetricsBuilder + .setInput(plan.getRoot) + .setName(name) + .addAllMetrics((expr +: exprs).map(_.expr).asJava) + } + } + + /** + * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This is + * equivalent to calling `observe(String, Column, Column*)` but does not require to collect all + * results before returning the metrics - the metrics are filled during iterating the results, + * as soon as they are available. This method does not support streaming datasets. + * + * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`. + * + * {{{ + * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it + * val observation = Observation("my_metrics") + * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid")) + * observed_ds.write.parquet("ds.parquet") + * val metrics = observation.get + * }}} + * + * @throws IllegalArgumentException + * If this is a streaming Dataset (this.isStreaming == true) + * + * @group typedrel + * @since 4.0.0 + */ + @scala.annotation.varargs + def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = { + val df = observe(observation.name, expr, exprs: _*) + sparkSession.registerObservation(df.getPlanId.get, observation) + df } def checkpoint(): Dataset[T] = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala new file mode 100644 index 0000000000000..75629b6000f91 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.UUID + +class Observation(name: String) extends ObservationBase(name) { + + /** + * Create an Observation instance without providing a name. This generates a random name. + */ + def this() = this(UUID.randomUUID().toString) +} + +/** + * (Scala-specific) Create instances of Observation via Scala `apply`. + * @since 4.0.0 + */ +object Observation { + + /** + * Observation constructor for creating an anonymous observation. + */ + def apply(): Observation = new Observation() + + /** + * Observation constructor for creating a named observation. + */ + def apply(name: String): Observation = new Observation(name) + +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 22bb62803fac5..1188fba60a2fe 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.io.Closeable import java.net.URI +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicLong, AtomicReference} @@ -36,7 +37,7 @@ import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder} -import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit @@ -80,6 +81,8 @@ class SparkSession private[sql] ( client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion } + private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]() + /** * Runtime configuration interface for Spark. 
* @@ -532,8 +535,12 @@ class SparkSession private[sql] ( private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = { val value = client.execute(plan) - val result = new SparkResult(value, allocator, encoder, timeZoneId) - result + new SparkResult( + value, + allocator, + encoder, + timeZoneId, + Some(setMetricsAndUnregisterObservation)) } private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = { @@ -554,6 +561,9 @@ class SparkSession private[sql] ( client.execute(plan).filter(!_.hasExecutionProgress).toSeq } + private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] = + client.execute(plan) + private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = { val command = proto.Command.newBuilder().setRegisterFunction(udf).build() execute(command) @@ -779,6 +789,21 @@ class SparkSession private[sql] ( * Set to false to prevent client.releaseSession on close() (testing only) */ private[sql] var releaseSessionOnClose = true + + private[sql] def registerObservation(planId: Long, observation: Observation): Unit = { + if (observationRegistry.putIfAbsent(planId, observation) != null) { + throw new IllegalArgumentException("An Observation can be used with a Dataset only once") + } + } + + private[sql] def setMetricsAndUnregisterObservation( + planId: Long, + metrics: Map[String, Any]): Unit = { + val observationOrNull = observationRegistry.remove(planId) + if (observationOrNull != null) { + observationOrNull.setMetricsAndNotify(Some(metrics)) + } + } } // The minimal builder needed to create a spark session. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index a0729adb89609..73a2f6d4f88e1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -22,6 +22,8 @@ import java.time.DateTimeException import java.util.Properties import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ import org.apache.commons.io.FileUtils @@ -41,6 +43,7 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper} import org.apache.spark.sql.test.SparkConnectServerUtils.port import org.apache.spark.sql.types._ +import org.apache.spark.util.SparkThreadUtils class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { @@ -1511,6 +1514,46 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM (0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1))) } } + + test("Observable metrics") { + val df = spark.range(99).withColumn("extra", col("id") - 1) + val ob1 = new Observation("ob1") + val observedDf = df.observe(ob1, min("id"), avg("id"), max("id")) + val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra")) + + val ob1Schema = new StructType() + .add("min(id)", LongType) + .add("avg(id)", DoubleType) + .add("max(id)", LongType) + val ob2Schema = new StructType() + .add("min(extra)", LongType) + .add("avg(extra)", DoubleType) + .add("max(extra)", LongType) + val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema)) + val 
ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema)) + + assert(df.collectResult().getObservedMetrics === Map.empty) + assert(observedDf.collectResult().getObservedMetrics === ob1Metrics) + assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics) + } + + test("Observation.get is blocked until the query is finished") { + val df = spark.range(99).withColumn("extra", col("id") - 1) + val observation = new Observation("ob1") + val observedDf = df.observe(observation, min("id"), avg("id"), max("id")) + + // Start a new thread to get the observation + val future = Future(observation.get)(ExecutionContext.global) + // make sure the thread is blocked right now + val e = intercept[java.util.concurrent.TimeoutException] { + SparkThreadUtils.awaitResult(future, 2.seconds) + } + assert(e.getMessage.contains("Future timed out")) + observedDf.collect() + // make sure the thread is unblocked after the query is finished + val metrics = SparkThreadUtils.awaitResult(future, 2.seconds) + assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98)) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index c89dba03ed699..7be5e2ecd1725 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -196,9 +196,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.COL_POS_KEY"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_KEY"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.curId"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.observe"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"), diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 49a33d3419b6f..77dda277602ab 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -434,6 +434,7 @@ message ExecutePlanResponse { string name = 1; repeated Expression.Literal values = 2; repeated string keys = 3; + int64 plan_id = 4; } message ResultComplete { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 93d1075aea025..0905ee76c3f34 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -27,10 +27,13 @@ import 
org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch} import org.apache.arrow.vector.types.pojo import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator} -import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ArrowUtils @@ -38,7 +41,8 @@ private[sql] class SparkResult[T]( responses: CloseableIterator[proto.ExecutePlanResponse], allocator: BufferAllocator, encoder: AgnosticEncoder[T], - timeZoneId: String) + timeZoneId: String, + setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None) extends AutoCloseable { self => case class StageInfo( @@ -79,6 +83,7 @@ private[sql] class SparkResult[T]( private[this] var arrowSchema: pojo.Schema = _ private[this] var nextResultIndex: Int = 0 private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])] + private val observedMetrics = mutable.Map.empty[String, Row] private val cleanable = SparkResult.cleaner.register(this, new SparkResultCloseable(resultMap, responses)) @@ -117,6 +122,9 @@ private[sql] class SparkResult[T]( while (!stop && responses.hasNext) { val response = responses.next() + // Collect metrics for this response + observedMetrics ++= processObservedMetrics(response.getObservedMetricsList) + // Save and validate operationId if (opId == null) { opId = response.getOperationId @@ -198,6 +206,29 @@ private[sql] class SparkResult[T]( nonEmpty } + private def processObservedMetrics( + metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = { + metrics.asScala.map { metric => + assert(metric.getKeysCount == metric.getValuesCount) + var schema = new StructType() + val keys = mutable.ListBuffer.empty[String] + val values = mutable.ListBuffer.empty[Any] + (0 until metric.getKeysCount).map { i => + val key = metric.getKeys(i) + val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i)) + schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass)) + keys += key + values += value + } + // If the metrics is registered by an Observation object, attach them and unblock any + // blocked thread. + setObservationMetricsOpt.foreach { setObservationMetrics => + setObservationMetrics(metric.getPlanId, keys.zip(values).toMap) + } + metric.getName -> new GenericRowWithSchema(values.toArray, schema) + } + } + /** * Returns the number of elements in the result. */ @@ -248,6 +279,15 @@ private[sql] class SparkResult[T]( result } + /** + * Returns all observed metrics in the result. + */ + def getObservedMetrics: Map[String, Row] = { + // We need to process all responses to get all metrics. + processResponses() + observedMetrics.toMap + } + /** * Returns an iterator over the contents of the result. 
*/ diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index ce42cc797bf38..1f3496fa89847 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -204,7 +204,7 @@ object LiteralValueProtoConverter { def toLiteralProto(literal: Any, dataType: DataType): proto.Expression.Literal = toLiteralProtoBuilder(literal, dataType).build() - private def toDataType(clz: Class[_]): DataType = clz match { + private[sql] def toDataType(clz: Class[_]): DataType = clz match { // primitive types case JShort.TYPE => ShortType case JInteger.TYPE => IntegerType diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 0a6d12cbb1918..4ef4f632204b3 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -220,6 +220,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends .createObservedMetricsResponse( executeHolder.sessionHolder.sessionId, executeHolder.sessionHolder.serverSessionId, + executeHolder.request.getPlan.getRoot.getCommon.getPlanId, observedMetrics ++ accumulatedInPython)) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 4f2b8c945127b..660951f229849 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -264,8 +264,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) name -> values } if (observedMetrics.nonEmpty) { - Some(SparkConnectPlanExecution - .createObservedMetricsResponse(sessionId, sessionHolder.serverSessionId, observedMetrics)) + val planId = executeHolder.request.getPlan.getRoot.getCommon.getPlanId + Some( + SparkConnectPlanExecution + .createObservedMetricsResponse( + sessionId, + sessionHolder.serverSessionId, + planId, + observedMetrics)) } else None } } @@ -274,11 +280,13 @@ object SparkConnectPlanExecution { def createObservedMetricsResponse( sessionId: String, serverSessionId: String, + planId: Long, metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = { val observedMetrics = metrics.map { case (name, values) => val metrics = ExecutePlanResponse.ObservedMetrics .newBuilder() .setName(name) + .setPlanId(planId) values.foreach { case (key, value) => metrics.addValues(toLiteralProto(value)) key.foreach(metrics.addKeys) diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 2a30ffe60a9f2..a39396db4ff1d 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf8\x13\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x11 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x02R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 
\x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xce\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 
\x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa3\x05\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x08 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x01R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x02R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB)\n\'_client_observed_server_side_session_idB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x15\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\x87\x01\n&streaming_query_listener_events_result\x18\x10 \x01(\x0b\x32\x31.spark.connect.StreamingQueryListenerEventsResultH\x00R"streamingQueryListenerEventsResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x87\x01\n&create_resource_profile_command_result\x18\x11 \x01(\x0b\x32\x31.spark.connect.CreateResourceProfileCommandResultH\x00R"createResourceProfileCommandResult\x12\x65\n\x12\x65xecution_progress\x18\x12 \x01(\x0b\x32\x34.spark.connect.ExecutePlanResponse.ExecutionProgressH\x00R\x11\x65xecutionProgress\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 
\x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultComplete\x1a\xcd\x02\n\x11\x45xecutionProgress\x12V\n\x06stages\x18\x01 \x03(\x0b\x32>.spark.connect.ExecutePlanResponse.ExecutionProgress.StageInfoR\x06stages\x12,\n\x12num_inflight_tasks\x18\x02 \x01(\x03R\x10numInflightTasks\x1a\xb1\x01\n\tStageInfo\x12\x19\n\x08stage_id\x18\x01 \x01(\x03R\x07stageId\x12\x1b\n\tnum_tasks\x18\x02 \x01(\x03R\x08numTasks\x12.\n\x13num_completed_tasks\x18\x03 \x01(\x03R\x11numCompletedTasks\x12(\n\x10input_bytes_read\x18\x04 \x01(\x03R\x0einputBytesRead\x12\x12\n\x04\x64one\x18\x05 \x01(\x08R\x04\x64oneB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x87\t\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x08 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keysB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xaf\x01\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x04 \x01(\tR\x13serverSideSessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xea\x07\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x02R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x90\x02\n\x14\x41\x64\x64\x41rtifactsResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc6\x02\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xe0\x02\n\x18\x41rtifactStatusesResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists"\xdb\x04\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x07 
\x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x02R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x90\x01\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x96\x03\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x06 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x02R\x0elastResponseId\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc9\x04\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x02R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xa5\x01\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xab\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"l\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x02 \x01(\tR\x13serverSideSessionId"\xcc\x02\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 
\x01(\tH\x01R\nclientType\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x93\x0c\n\x19\x46\x65tchErrorDetailsResponse\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\x1d\n\nsession_id\x18\x04 \x01(\tR\tsessionId\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xf0\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xb2\x07\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a 
.spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf8\x13\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x11 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x02R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 
\x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xce\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 
\x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa3\x05\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x08 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x01R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x02R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB)\n\'_client_observed_server_side_session_idB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\x80\x16\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\x87\x01\n&streaming_query_listener_events_result\x18\x10 \x01(\x0b\x32\x31.spark.connect.StreamingQueryListenerEventsResultH\x00R"streamingQueryListenerEventsResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x87\x01\n&create_resource_profile_command_result\x18\x11 \x01(\x0b\x32\x31.spark.connect.CreateResourceProfileCommandResultH\x00R"createResourceProfileCommandResult\x12\x65\n\x12\x65xecution_progress\x18\x12 
\x01(\x0b\x32\x34.spark.connect.ExecutePlanResponse.ExecutionProgressH\x00R\x11\x65xecutionProgress\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1a\x8d\x01\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x12\x17\n\x07plan_id\x18\x04 \x01(\x03R\x06planId\x1a\x10\n\x0eResultComplete\x1a\xcd\x02\n\x11\x45xecutionProgress\x12V\n\x06stages\x18\x01 \x03(\x0b\x32>.spark.connect.ExecutePlanResponse.ExecutionProgress.StageInfoR\x06stages\x12,\n\x12num_inflight_tasks\x18\x02 \x01(\x03R\x10numInflightTasks\x1a\xb1\x01\n\tStageInfo\x12\x19\n\x08stage_id\x18\x01 \x01(\x03R\x07stageId\x12\x1b\n\tnum_tasks\x18\x02 \x01(\x03R\x08numTasks\x12.\n\x13num_completed_tasks\x18\x03 \x01(\x03R\x11numCompletedTasks\x12(\n\x10input_bytes_read\x18\x04 \x01(\x03R\x0einputBytesRead\x12\x12\n\x04\x64one\x18\x05 \x01(\x08R\x04\x64oneB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x87\t\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x08 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 
\x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xaf\x01\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x04 \x01(\tR\x13serverSideSessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xea\x07\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x02R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x90\x02\n\x14\x41\x64\x64\x41rtifactsResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc6\x02\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 
\x03(\tR\x05namesB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xe0\x02\n\x18\x41rtifactStatusesResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists"\xdb\x04\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x02R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x90\x01\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x96\x03\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x06 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x02R\x0elastResponseId\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc9\x04\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x02R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xa5\x01\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12&\n\x0coperation_id\x18\x02 
\x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xab\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"l\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x02 \x01(\tR\x13serverSideSessionId"\xcc\x02\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x93\x0c\n\x19\x46\x65tchErrorDetailsResponse\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\x1d\n\nsession_id\x18\x04 \x01(\tR\tsessionId\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xf0\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 
\x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xb2\x07\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -120,7 +120,7 @@ _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 5196 _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5361 _EXECUTEPLANRESPONSE._serialized_start = 5440 - _EXECUTEPLANRESPONSE._serialized_end = 8230 + _EXECUTEPLANRESPONSE._serialized_end = 8256 _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 7030 _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 7101 _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 7103 @@ -133,96 +133,96 @@ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 7651 _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 7653 _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 7741 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7743 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7859 - _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7861 - _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7877 - _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_start = 7880 - _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_end = 8213 - _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_start = 8036 - _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_end = 8213 - _KEYVALUE._serialized_start = 8232 - _KEYVALUE._serialized_end = 8297 - _CONFIGREQUEST._serialized_start = 8300 - _CONFIGREQUEST._serialized_end = 9459 - _CONFIGREQUEST_OPERATION._serialized_start = 8608 - _CONFIGREQUEST_OPERATION._serialized_end = 9106 - _CONFIGREQUEST_SET._serialized_start = 9108 - _CONFIGREQUEST_SET._serialized_end = 9160 - _CONFIGREQUEST_GET._serialized_start = 9162 - _CONFIGREQUEST_GET._serialized_end = 9187 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 9189 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 9252 - _CONFIGREQUEST_GETOPTION._serialized_start = 9254 - _CONFIGREQUEST_GETOPTION._serialized_end = 9285 - _CONFIGREQUEST_GETALL._serialized_start = 9287 - _CONFIGREQUEST_GETALL._serialized_end = 9335 - _CONFIGREQUEST_UNSET._serialized_start = 9337 - 
_CONFIGREQUEST_UNSET._serialized_end = 9364 - _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 9366 - _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 9400 - _CONFIGRESPONSE._serialized_start = 9462 - _CONFIGRESPONSE._serialized_end = 9637 - _ADDARTIFACTSREQUEST._serialized_start = 9640 - _ADDARTIFACTSREQUEST._serialized_end = 10642 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 10115 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 10168 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 10170 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 10281 - _ADDARTIFACTSREQUEST_BATCH._serialized_start = 10283 - _ADDARTIFACTSREQUEST_BATCH._serialized_end = 10376 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 10379 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10572 - _ADDARTIFACTSRESPONSE._serialized_start = 10645 - _ADDARTIFACTSRESPONSE._serialized_end = 10917 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10836 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10917 - _ARTIFACTSTATUSESREQUEST._serialized_start = 10920 - _ARTIFACTSTATUSESREQUEST._serialized_end = 11246 - _ARTIFACTSTATUSESRESPONSE._serialized_start = 11249 - _ARTIFACTSTATUSESRESPONSE._serialized_end = 11601 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11444 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11559 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11561 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11601 - _INTERRUPTREQUEST._serialized_start = 11604 - _INTERRUPTREQUEST._serialized_end = 12207 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 12007 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 12135 - _INTERRUPTRESPONSE._serialized_start = 12210 - _INTERRUPTRESPONSE._serialized_end = 12354 - _REATTACHOPTIONS._serialized_start = 12356 - _REATTACHOPTIONS._serialized_end = 12409 - _REATTACHEXECUTEREQUEST._serialized_start = 12412 - _REATTACHEXECUTEREQUEST._serialized_end = 12818 - _RELEASEEXECUTEREQUEST._serialized_start = 12821 - _RELEASEEXECUTEREQUEST._serialized_end = 13406 - _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 13275 - _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 13287 - _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 13289 - _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 13336 - _RELEASEEXECUTERESPONSE._serialized_start = 13409 - _RELEASEEXECUTERESPONSE._serialized_end = 13574 - _RELEASESESSIONREQUEST._serialized_start = 13577 - _RELEASESESSIONREQUEST._serialized_end = 13748 - _RELEASESESSIONRESPONSE._serialized_start = 13750 - _RELEASESESSIONRESPONSE._serialized_end = 13858 - _FETCHERRORDETAILSREQUEST._serialized_start = 13861 - _FETCHERRORDETAILSREQUEST._serialized_end = 14193 - _FETCHERRORDETAILSRESPONSE._serialized_start = 14196 - _FETCHERRORDETAILSRESPONSE._serialized_end = 15751 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 14425 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14599 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14602 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14970 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 14933 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14970 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14973 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 15382 - 
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 15284 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 15352 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 15385 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15732 - _SPARKCONNECTSERVICE._serialized_start = 15754 - _SPARKCONNECTSERVICE._serialized_end = 16700 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7744 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7885 + _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7887 + _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7903 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_start = 7906 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_end = 8239 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_start = 8062 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_end = 8239 + _KEYVALUE._serialized_start = 8258 + _KEYVALUE._serialized_end = 8323 + _CONFIGREQUEST._serialized_start = 8326 + _CONFIGREQUEST._serialized_end = 9485 + _CONFIGREQUEST_OPERATION._serialized_start = 8634 + _CONFIGREQUEST_OPERATION._serialized_end = 9132 + _CONFIGREQUEST_SET._serialized_start = 9134 + _CONFIGREQUEST_SET._serialized_end = 9186 + _CONFIGREQUEST_GET._serialized_start = 9188 + _CONFIGREQUEST_GET._serialized_end = 9213 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 9215 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 9278 + _CONFIGREQUEST_GETOPTION._serialized_start = 9280 + _CONFIGREQUEST_GETOPTION._serialized_end = 9311 + _CONFIGREQUEST_GETALL._serialized_start = 9313 + _CONFIGREQUEST_GETALL._serialized_end = 9361 + _CONFIGREQUEST_UNSET._serialized_start = 9363 + _CONFIGREQUEST_UNSET._serialized_end = 9390 + _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 9392 + _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 9426 + _CONFIGRESPONSE._serialized_start = 9488 + _CONFIGRESPONSE._serialized_end = 9663 + _ADDARTIFACTSREQUEST._serialized_start = 9666 + _ADDARTIFACTSREQUEST._serialized_end = 10668 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 10141 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 10194 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 10196 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 10307 + _ADDARTIFACTSREQUEST_BATCH._serialized_start = 10309 + _ADDARTIFACTSREQUEST_BATCH._serialized_end = 10402 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 10405 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10598 + _ADDARTIFACTSRESPONSE._serialized_start = 10671 + _ADDARTIFACTSRESPONSE._serialized_end = 10943 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10862 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10943 + _ARTIFACTSTATUSESREQUEST._serialized_start = 10946 + _ARTIFACTSTATUSESREQUEST._serialized_end = 11272 + _ARTIFACTSTATUSESRESPONSE._serialized_start = 11275 + _ARTIFACTSTATUSESRESPONSE._serialized_end = 11627 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11470 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11585 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11587 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11627 + _INTERRUPTREQUEST._serialized_start = 11630 + _INTERRUPTREQUEST._serialized_end = 12233 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 12033 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 12161 + _INTERRUPTRESPONSE._serialized_start = 12236 + 
_INTERRUPTRESPONSE._serialized_end = 12380 + _REATTACHOPTIONS._serialized_start = 12382 + _REATTACHOPTIONS._serialized_end = 12435 + _REATTACHEXECUTEREQUEST._serialized_start = 12438 + _REATTACHEXECUTEREQUEST._serialized_end = 12844 + _RELEASEEXECUTEREQUEST._serialized_start = 12847 + _RELEASEEXECUTEREQUEST._serialized_end = 13432 + _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 13301 + _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 13313 + _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 13315 + _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 13362 + _RELEASEEXECUTERESPONSE._serialized_start = 13435 + _RELEASEEXECUTERESPONSE._serialized_end = 13600 + _RELEASESESSIONREQUEST._serialized_start = 13603 + _RELEASESESSIONREQUEST._serialized_end = 13774 + _RELEASESESSIONRESPONSE._serialized_start = 13776 + _RELEASESESSIONRESPONSE._serialized_end = 13884 + _FETCHERRORDETAILSREQUEST._serialized_start = 13887 + _FETCHERRORDETAILSREQUEST._serialized_end = 14219 + _FETCHERRORDETAILSRESPONSE._serialized_start = 14222 + _FETCHERRORDETAILSRESPONSE._serialized_end = 15777 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 14451 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14625 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14628 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14996 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 14959 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14996 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14999 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 15408 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 15310 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 15378 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 15411 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15758 + _SPARKCONNECTSERVICE._serialized_start = 15780 + _SPARKCONNECTSERVICE._serialized_end = 16726 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index d22502f8839db..b76f2a7f4de34 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -1406,6 +1406,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): NAME_FIELD_NUMBER: builtins.int VALUES_FIELD_NUMBER: builtins.int KEYS_FIELD_NUMBER: builtins.int + PLAN_ID_FIELD_NUMBER: builtins.int name: builtins.str @property def values( @@ -1417,6 +1418,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): def keys( self, ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... + plan_id: builtins.int def __init__( self, *, @@ -1426,11 +1428,12 @@ class ExecutePlanResponse(google.protobuf.message.Message): ] | None = ..., keys: collections.abc.Iterable[builtins.str] | None = ..., + plan_id: builtins.int = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ - "keys", b"keys", "name", b"name", "values", b"values" + "keys", b"keys", "name", b"name", "plan_id", b"plan_id", "values", b"values" ], ) -> None: ... 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala new file mode 100644 index 0000000000000..4789ae8975d12 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.jdk.CollectionConverters.MapHasAsJava + +/** + * Helper class to simplify usage of `Dataset.observe(String, Column, Column*)`: + * + * {{{ + * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it + * val observation = Observation("my metrics") + * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid")) + * observed_ds.write.parquet("ds.parquet") + * val metrics = observation.get + * }}} + * + * This collects the metrics while the first action is executed on the observed dataset. Subsequent + * actions do not modify the metrics returned by [[get]]. Retrieval of the metric via [[get]] + * blocks until the first action has finished and metrics become available. + * + * This class does not support streaming datasets. + * + * @param name name of the metric + * @since 3.3.0 + */ +abstract class ObservationBase(val name: String) { + + if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty") + + @volatile protected var metrics: Option[Map[String, Any]] = None + + /** + * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish + * its first action. Only the result of the first action is available. Subsequent actions do not + * modify the result. + * + * @return the observed metrics as a `Map[String, Any]` + * @throws InterruptedException interrupted while waiting + */ + @throws[InterruptedException] + def get: Map[String, _] = { + synchronized { + // we need to loop as wait might return without us calling notify + // https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610 + while (this.metrics.isEmpty) { + wait() + } + } + + this.metrics.get + } + + /** + * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish + * its first action. Only the result of the first action is available. Subsequent actions do not + * modify the result. + * + * @return the observed metrics as a `java.util.Map[String, Object]` + * @throws InterruptedException interrupted while waiting + */ + @throws[InterruptedException] + def getAsJava: java.util.Map[String, AnyRef] = { + get.map { case (key, value) => (key, value.asInstanceOf[Object]) }.asJava + } + + /** + * Get the observed metrics. This returns the metrics if they are available, otherwise an empty. 
+ * + * @return the observed metrics as a `Map[String, Any]` + */ + @throws[InterruptedException] + private[sql] def getOrEmpty: Map[String, _] = { + synchronized { + if (metrics.isEmpty) { + wait(100) // Wait for 100ms to see if metrics are available + } + metrics.getOrElse(Map.empty) + } + } + + /** + * Set the observed metrics and notify all waiting threads to resume. + * + * @return `true` if all waiting threads were notified, `false` if otherwise. + */ + private[spark] def setMetricsAndNotify(metrics: Option[Map[String, Any]]): Boolean = { + synchronized { + this.metrics = metrics + if(metrics.isDefined) { + notifyAll() + true + } else { + false + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala index 104e7c101fd1c..30d5943c60922 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import java.util.UUID -import scala.jdk.CollectionConverters.MapHasAsJava - import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener @@ -47,9 +45,7 @@ import org.apache.spark.util.ArrayImplicits._ * @param name name of the metric * @since 3.3.0 */ -class Observation(val name: String) { - - if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty") +class Observation(name: String) extends ObservationBase(name) { /** * Create an Observation instance without providing a name. This generates a random name. @@ -60,8 +56,6 @@ class Observation(val name: String) { @volatile private var dataframeId: Option[(SparkSession, Long)] = None - @volatile private var metrics: Option[Map[String, Any]] = None - /** * Attach this observation to the given [[Dataset]] to observe aggregation expressions. * @@ -83,55 +77,6 @@ class Observation(val name: String) { ds.observe(name, expr, exprs: _*) } - /** - * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish - * its first action. Only the result of the first action is available. Subsequent actions do not - * modify the result. - * - * @return the observed metrics as a `Map[String, Any]` - * @throws InterruptedException interrupted while waiting - */ - @throws[InterruptedException] - def get: Map[String, _] = { - synchronized { - // we need to loop as wait might return without us calling notify - // https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610 - while (this.metrics.isEmpty) { - wait() - } - } - - this.metrics.get - } - - /** - * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish - * its first action. Only the result of the first action is available. Subsequent actions do not - * modify the result. - * - * @return the observed metrics as a `java.util.Map[String, Object]` - * @throws InterruptedException interrupted while waiting - */ - @throws[InterruptedException] - def getAsJava: java.util.Map[String, AnyRef] = { - get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava - } - - /** - * Get the observed metrics. This returns the metrics if they are available, otherwise an empty. 
- * - * @return the observed metrics as a `Map[String, Any]` - */ - @throws[InterruptedException] - private[sql] def getOrEmpty: Map[String, _] = { - synchronized { - if (metrics.isEmpty) { - wait(100) // Wait for 100ms to see if metrics are available - } - metrics.getOrElse(Map.empty) - } - } - private[sql] def register(sparkSession: SparkSession, dataframeId: Long): Unit = { // makes this class thread-safe: // only the first thread entering this block can set sparkSession @@ -158,9 +103,8 @@ class Observation(val name: String) { case _ => false }) { val row = qe.observedMetrics.get(name) - this.metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq)) - if (metrics.isDefined) { - notifyAll() + val metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq)) + if (setMetricsAndNotify(metrics)) { unregister() } } From 4fb6624bd2cec0fec893ea0ac65b1a02c60384ec Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Thu, 9 May 2024 08:22:48 +0900 Subject: [PATCH 15/65] [SPARK-48205][PYTHON] Remove the private[sql] modifier for Python data sources ### What changes were proposed in this pull request? This PR removes the `private[sql]` modifier for Python data sources to make it consistent with UDFs and UDTFs. ### Why are the changes needed? ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #46487 from allisonwang-db/spark-48205-pyds-modifier. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/sql/DataSourceRegistration.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala index 63cee8861c5a4..8ffdbb952b082 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.internal.SQLConf * Use `SparkSession.dataSource` to access this. */ @Evolving -private[sql] class DataSourceRegistration private[sql] (dataSourceManager: DataSourceManager) +class DataSourceRegistration private[sql] (dataSourceManager: DataSourceManager) extends Logging { protected[sql] def registerPython( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 5d85f070fbbe4..d5de74455dceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -234,7 +234,7 @@ class SparkSession private( /** * A collection of methods for registering user-defined data sources. */ - private[sql] def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration + def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration /** * Returns a `StreamingQueryManager` that allows managing all the From 337f980f0073c8605ed2738186d2089a362b7f66 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 8 May 2024 16:47:01 -0700 Subject: [PATCH 16/65] [SPARK-48204][INFRA] Fix release script for Spark 4.0+ ### What changes were proposed in this pull request? Before Spark 4.0, Scala 2.12 was primary and Scala 2.13 was secondary. 
The release scripts build more packages (hadoop3, without-hadoop, pyspark, sparkr) for the primary Scala version but only one package for the secondary. However, Spark 4.0 removes Scala 2.12 support and the release script needs to be updated accordingly. ### Why are the changes needed? to make the release scripts work ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manual. ### Was this patch authored or co-authored using generative AI tooling? no Closes #46484 from cloud-fan/re. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- dev/create-release/release-build.sh | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 75ec98464f3ec..b720a8fc93861 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -194,6 +194,8 @@ fi PUBLISH_SCALA_2_12=1 if [[ $SPARK_VERSION > "3.5.99" ]]; then PUBLISH_SCALA_2_12=0 + # There is no longer scala-2.13 profile since 4.0.0 + SCALA_2_13_PROFILES="" fi SCALA_2_12_PROFILES="-Pscala-2.12" @@ -345,21 +347,25 @@ if [[ "$1" == "package" ]]; then declare -A BINARY_PKGS_EXTRA BINARY_PKGS_EXTRA["hadoop3"]="withpip,withr" - if [[ $PUBLISH_SCALA_2_13 = 1 ]]; then - key="hadoop3-scala2.13" + # This is dead code as Scala 2.12 is no longer supported, but we keep it as a template for + # adding new Scala version support in the future. This secondary Scala version only has one + # binary package to avoid doubling the number of final packages. It doesn't build PySpark and + # SparkR as the primary Scala version will build them. + if [[ $PUBLISH_SCALA_2_12 = 1 ]]; then + key="hadoop3-scala2.12" args="-Phadoop-3 $HIVE_PROFILES" extra="" - if ! make_binary_release "$key" "$SCALA_2_13_PROFILES $args" "$extra" "2.13"; then + if ! make_binary_release "$key" "$SCALA_2_12_PROFILES $args" "$extra" "2.12"; then error "Failed to build $key package. Check logs for details." fi fi - if [[ $PUBLISH_SCALA_2_12 = 1 ]]; then + if [[ $PUBLISH_SCALA_2_13 = 1 ]]; then echo "Packages to build: ${!BINARY_PKGS_ARGS[@]}" for key in ${!BINARY_PKGS_ARGS[@]}; do args=${BINARY_PKGS_ARGS[$key]} extra=${BINARY_PKGS_EXTRA[$key]} - if ! make_binary_release "$key" "$SCALA_2_12_PROFILES $args" "$extra" "2.12"; then + if ! make_binary_release "$key" "$SCALA_2_13_PROFILES $args" "$extra" "2.13"; then error "Failed to build $key package. Check logs for details." fi done From 7e79e91dc8c531ee9135f0e32a9aa2e1f80c4bbf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 May 2024 10:56:21 +0800 Subject: [PATCH 17/65] [SPARK-48197][SQL] Avoid assert error for invalid lambda function ### What changes were proposed in this pull request? `ExpressionBuilder` asserts all its input expressions to be resolved during lookup, which is not true as the analyzer rule `ResolveFunctions` can trigger function lookup even if the input expression contains unresolved lambda functions. This PR updates that assert to check non-lambda inputs only, and fail earlier if the input contains lambda functions. In the future, if we use `ExpressionBuilder` to register higher-order functions, we can relax it. ### Why are the changes needed? better error message ### Does this PR introduce _any_ user-facing change? no, only changes error message ### How was this patch tested? new test ### Was this patch authored or co-authored using generative AI tooling? no Closes #46475 from cloud-fan/minor. 
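For reference, a minimal Scala sketch of the query shape covered by the new `higher-order-functions.sql` test. The local session setup and the `Try` wrapper are assumptions added only for this illustration and are not part of the patch; the query text and the error class name are taken from the diff below.

```scala
import scala.util.Try
import org.apache.spark.sql.SparkSession

// Assumed local session purely for the repro; any SparkSession works.
val spark = SparkSession.builder()
  .appName("lambda-repro")
  .master("local[*]")
  .getOrCreate()

// `ceil` is registered through an ExpressionBuilder, not as a higher-order function,
// so passing a lambda is invalid. Before this change, analysis tripped the internal
// assert "function arguments must be resolved"; with it, the query fails with the
// dedicated error class INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION.
val result = Try(spark.sql("select ceil(x -> x) as v"))
println(result) // expected: Failure(org.apache.spark.sql.AnalysisException: ... NON_HIGHER_ORDER_FUNCTION ...)

spark.stop()
```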
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../catalyst/analysis/FunctionRegistry.scala | 9 +++++++- .../plans/logical/FunctionBuilderBase.scala | 2 ++ .../ansi/higher-order-functions.sql.out | 20 +++++++++++++++++ .../higher-order-functions.sql.out | 20 +++++++++++++++++ .../inputs/higher-order-functions.sql | 2 ++ .../ansi/higher-order-functions.sql.out | 22 +++++++++++++++++++ .../results/higher-order-functions.sql.out | 22 +++++++++++++++++++ 7 files changed, 96 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6565591b79524..f37f47c13ed45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -955,7 +955,14 @@ object FunctionRegistry { since: Option[String] = None): (String, (ExpressionInfo, FunctionBuilder)) = { val info = FunctionRegistryBase.expressionInfo[T](name, since) val funcBuilder = (expressions: Seq[Expression]) => { - assert(expressions.forall(_.resolved), "function arguments must be resolved.") + val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction]) + if (lambdas.nonEmpty && !builder.supportsLambda) { + throw new AnalysisException( + errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + messageParameters = Map( + "class" -> builder.getClass.getCanonicalName)) + } + assert(others.forall(_.resolved), "function arguments must be resolved.") val rearrangedExpressions = rearrangeExpressions(name, builder, expressions) val expr = builder.build(name, rearrangedExpressions) if (setAlias) expr.setTagValue(FUNC_ALIAS, name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 7e04af190e4aa..0aa73f1939e10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -70,6 +70,8 @@ trait FunctionBuilderBase[T] { } def build(funcName: String, expressions: Seq[Expression]): T + + def supportsLambda: Boolean = false } object NamedParametersSupport { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out index 693cb2a046319..a772a2c92e672 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out @@ -35,6 +35,26 @@ org.apache.spark.sql.AnalysisException } +-- !query +select ceil(x -> x) as v +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + "sqlState" : "42K0D", + "messageParameters" : { + "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "ceil(x -> x)" + } ] +} + + -- !query select transform(zs, z -> z) as v from nested -- !query analysis diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out index ec6d7271cc235..c82ba7d062016 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out @@ -35,6 +35,26 @@ org.apache.spark.sql.AnalysisException } +-- !query +select ceil(x -> x) as v +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + "sqlState" : "42K0D", + "messageParameters" : { + "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "ceil(x -> x)" + } ] +} + + -- !query select transform(zs, z -> z) as v from nested -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 7925a21de04cd..37081de012e98 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -11,6 +11,8 @@ create or replace temporary view nested as values -- Only allow lambda's in higher order functions. select upper(x -> x) as v; +-- Also test functions registered with `ExpressionBuilder`. +select ceil(x -> x) as v; -- Identity transform an array select transform(zs, z -> z) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out index ee4525285a9be..7bfc35a61e092 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out @@ -33,6 +33,28 @@ org.apache.spark.sql.AnalysisException } +-- !query +select ceil(x -> x) as v +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + "sqlState" : "42K0D", + "messageParameters" : { + "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "ceil(x -> x)" + } ] +} + + -- !query select transform(zs, z -> z) as v from nested -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index ee4525285a9be..7bfc35a61e092 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -33,6 +33,28 @@ org.apache.spark.sql.AnalysisException } +-- !query +select ceil(x -> x) as v +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + "sqlState" : "42K0D", + "messageParameters" : { + "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + 
"fragment" : "ceil(x -> x)" + } ] +} + + -- !query select transform(zs, z -> z) as v from nested -- !query schema From 5891b20ef492e3dad31ff851770d9c4f9c7c4de4 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Wed, 8 May 2024 21:56:55 -0700 Subject: [PATCH 18/65] [SPARK-47186][TESTS][FOLLOWUP] Correct the name of spark.test.docker.connectionTimeout ### What changes were proposed in this pull request? This PR adds a followup of SPARK-47186 to correct the name of spark.test.docker.connectionTimeout ### Why are the changes needed? test bugfix ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46495 from yaooqinn/SPARK-47186-FF. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index ded7bb3a6bf65..8d17e0b4e36e6 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -115,7 +115,7 @@ abstract class DockerJDBCIntegrationSuite protected val startContainerTimeout: Long = timeStringAsSeconds(sys.props.getOrElse("spark.test.docker.startContainerTimeout", "5min")) protected val connectionTimeout: PatienceConfiguration.Timeout = { - val timeoutStr = sys.props.getOrElse("spark.test.docker.conn", "5min") + val timeoutStr = sys.props.getOrElse("spark.test.docker.connectionTimeout", "5min") timeout(timeStringAsSeconds(timeoutStr).seconds) } From 85a6e35d834eabef0bdcf9ff5bcf16eea669c828 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 May 2024 22:44:39 -0700 Subject: [PATCH 19/65] [SPARK-48182][SQL] SQL (java side): Migrate `error/warn/info` with variables to structured logging framework ### What changes were proposed in this pull request? The pr aims to 1.migrate `error/warn/info` in module `SQL` with variables to `structured logging framework` for java side. 2.convert all dependencies on `org.slf4j.Logger & org.slf4j.LoggerFactory` to `org.apache.spark.internal.Logger & org.apache.spark.internal.LoggerFactory`, in order to completely `prohibit` importing `org.slf4j.Logger & org.slf4j.LoggerFactory` in java code later. ### Why are the changes needed? To enhance Apache Spark's logging system by implementing structured logging. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46450 from panbingkun/sql_java_sl. 
Authored-by: panbingkun Signed-off-by: Gengliang Wang --- .../org/apache/spark/internal/Logger.java | 4 ++ .../org/apache/spark/internal/LogKey.scala | 13 +++++ .../expressions/RowBasedKeyValueBatch.java | 11 +++-- .../sql/util/CaseInsensitiveStringMap.java | 18 ++++--- .../apache/hive/service/AbstractService.java | 13 +++-- .../apache/hive/service/CompositeService.java | 14 ++++-- .../org/apache/hive/service/CookieSigner.java | 5 +- .../hive/service/ServiceOperations.java | 12 +++-- .../org/apache/hive/service/ServiceUtils.java | 2 +- .../hive/service/auth/HiveAuthFactory.java | 15 +++--- .../hive/service/auth/HttpAuthUtils.java | 12 +++-- .../service/auth/TSetIpAddressProcessor.java | 7 +-- .../apache/hive/service/cli/CLIService.java | 21 +++++--- .../hive/service/cli/ColumnBasedSet.java | 9 ++-- .../operation/ClassicTableTypeMapping.java | 13 +++-- .../hive/service/cli/operation/Operation.java | 28 ++++++----- .../cli/operation/OperationManager.java | 10 ++-- .../service/cli/session/HiveSessionImpl.java | 49 +++++++++++-------- .../service/cli/session/SessionManager.java | 49 +++++++++++-------- .../service/cli/thrift/ThriftCLIService.java | 16 +++--- .../service/cli/thrift/ThriftHttpServlet.java | 14 ++++-- .../hive/service/server/HiveServer2.java | 12 +++-- .../server/ThreadWithGarbageCleanup.java | 5 +- .../thriftserver/SparkSQLCLIService.scala | 2 +- 24 files changed, 222 insertions(+), 132 deletions(-) diff --git a/common/utils/src/main/java/org/apache/spark/internal/Logger.java b/common/utils/src/main/java/org/apache/spark/internal/Logger.java index d8ab26424bae5..7c54e912b189a 100644 --- a/common/utils/src/main/java/org/apache/spark/internal/Logger.java +++ b/common/utils/src/main/java/org/apache/spark/internal/Logger.java @@ -193,4 +193,8 @@ static MessageThrowable of(String message, Throwable throwable) { return new MessageThrowable(message, throwable); } } + + public org.slf4j.Logger getSlf4jLogger() { + return slf4jLogger; + } } diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index 14e822c6349f3..78be240619405 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -230,6 +230,7 @@ object LogKeys { case object FROM_TIME extends LogKey case object FUNCTION_NAME extends LogKey case object FUNCTION_PARAMETER extends LogKey + case object GLOBAL_INIT_FILE extends LogKey case object GLOBAL_WATERMARK extends LogKey case object GROUP_BY_EXPRS extends LogKey case object GROUP_ID extends LogKey @@ -275,6 +276,7 @@ object LogKeys { case object KAFKA_RECORDS_PULLED_COUNT extends LogKey case object KEY extends LogKey case object KEYTAB extends LogKey + case object KEYTAB_FILE extends LogKey case object LABEL_COLUMN extends LogKey case object LARGEST_CLUSTER_INDEX extends LogKey case object LAST_ACCESS_TIME extends LogKey @@ -290,6 +292,7 @@ object LogKeys { case object LOADED_VERSION extends LogKey case object LOAD_FACTOR extends LogKey case object LOAD_TIME extends LogKey + case object LOCAL_SCRATCH_DIR extends LogKey case object LOCATION extends LogKey case object LOGICAL_PLAN_COLUMNS extends LogKey case object LOGICAL_PLAN_LEAVES extends LogKey @@ -411,6 +414,8 @@ object LogKeys { case object OLD_GENERATION_GC extends LogKey case object OLD_VALUE extends LogKey case object OPEN_COST_IN_BYTES extends LogKey + case object OPERATION_HANDLE extends LogKey + case object 
OPERATION_HANDLE_IDENTIFIER extends LogKey case object OPTIMIZED_PLAN_COLUMNS extends LogKey case object OPTIMIZER_CLASS_NAME extends LogKey case object OPTIONS extends LogKey @@ -458,6 +463,7 @@ object LogKeys { case object PROCESSING_TIME extends LogKey case object PRODUCER_ID extends LogKey case object PROPERTY_NAME extends LogKey + case object PROTOCOL_VERSION extends LogKey case object PROVIDER extends LogKey case object PUSHED_FILTERS extends LogKey case object PVC_METADATA_NAME extends LogKey @@ -523,9 +529,11 @@ object LogKeys { case object SERVER_NAME extends LogKey case object SERVICE_NAME extends LogKey case object SERVLET_CONTEXT_HANDLER_PATH extends LogKey + case object SESSION_HANDLE extends LogKey case object SESSION_HOLD_INFO extends LogKey case object SESSION_ID extends LogKey case object SESSION_KEY extends LogKey + case object SET_CLIENT_INFO_REQUEST extends LogKey case object SHARD_ID extends LogKey case object SHELL_COMMAND extends LogKey case object SHUFFLE_BLOCK_INFO extends LogKey @@ -578,6 +586,7 @@ object LogKeys { case object SUBSAMPLING_RATE extends LogKey case object SUB_QUERY extends LogKey case object TABLE_NAME extends LogKey + case object TABLE_TYPE extends LogKey case object TABLE_TYPES extends LogKey case object TARGET_NUM_EXECUTOR extends LogKey case object TARGET_NUM_EXECUTOR_DELTA extends LogKey @@ -595,6 +604,9 @@ object LogKeys { case object THREAD extends LogKey case object THREAD_ID extends LogKey case object THREAD_NAME extends LogKey + case object THREAD_POOL_KEEPALIVE_TIME extends LogKey + case object THREAD_POOL_SIZE extends LogKey + case object THREAD_POOL_WAIT_QUEUE_SIZE extends LogKey case object TID extends LogKey case object TIME extends LogKey case object TIMEOUT extends LogKey @@ -602,6 +614,7 @@ object LogKeys { case object TIMESTAMP extends LogKey case object TIME_UNITS extends LogKey case object TIP extends LogKey + case object TOKEN extends LogKey case object TOKEN_KIND extends LogKey case object TOKEN_REGEX extends LogKey case object TOPIC extends LogKey diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java index 6a74f64d44849..be7e682a3bdf5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -19,16 +19,16 @@ import java.io.Closeable; import java.io.IOException; +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - /** * RowBasedKeyValueBatch stores key value pairs in contiguous memory region. 
* @@ -127,7 +127,8 @@ private boolean acquirePage(long requiredSize) { try { page = allocatePage(requiredSize); } catch (SparkOutOfMemoryError e) { - logger.warn("Failed to allocate page ({} bytes).", requiredSize); + logger.warn("Failed to allocate page ({} bytes).", + MDC.of(LogKeys.PAGE_SIZE$.MODULE$, requiredSize)); return false; } base = page.getBaseObject(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java index 00a3de692fbf4..d66524d841ca6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java @@ -17,12 +17,6 @@ package org.apache.spark.sql.util; -import org.apache.spark.SparkIllegalArgumentException; -import org.apache.spark.SparkUnsupportedOperationException; -import org.apache.spark.annotation.Experimental; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -31,6 +25,14 @@ import java.util.Objects; import java.util.Set; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; +import org.apache.spark.SparkIllegalArgumentException; +import org.apache.spark.SparkUnsupportedOperationException; + /** * Case-insensitive map of string keys to string values. *

@@ -59,8 +61,8 @@ public CaseInsensitiveStringMap(Map originalMap) { for (Map.Entry entry : originalMap.entrySet()) { String key = toLowerCase(entry.getKey()); if (delegate.containsKey(key)) { - logger.warn("Converting duplicated key " + entry.getKey() + - " into CaseInsensitiveStringMap."); + logger.warn("Converting duplicated key {} into CaseInsensitiveStringMap.", + MDC.of(LogKeys.KEY$.MODULE$, entry.getKey())); } delegate.put(key, entry.getValue()); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java index 6481cf15075a7..009b9f253ce0d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java @@ -21,8 +21,11 @@ import java.util.List; import org.apache.hadoop.hive.conf.HiveConf; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * AbstractService. @@ -85,7 +88,7 @@ public synchronized void init(HiveConf hiveConf) { ensureCurrentState(STATE.NOTINITED); this.hiveConf = hiveConf; changeState(STATE.INITED); - LOG.info("Service:" + getName() + " is inited."); + LOG.info("Service:{} is inited.", MDC.of(LogKeys.SERVICE_NAME$.MODULE$, getName())); } /** @@ -100,7 +103,7 @@ public synchronized void start() { startTime = System.currentTimeMillis(); ensureCurrentState(STATE.INITED); changeState(STATE.STARTED); - LOG.info("Service:" + getName() + " is started."); + LOG.info("Service:{} is started.", MDC.of(LogKeys.SERVICE_NAME$.MODULE$, getName())); } /** @@ -121,7 +124,7 @@ public synchronized void stop() { } ensureCurrentState(STATE.STARTED); changeState(STATE.STOPPED); - LOG.info("Service:" + getName() + " is stopped."); + LOG.info("Service:{} is stopped.", MDC.of(LogKeys.SERVICE_NAME$.MODULE$, getName())); } @Override diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CompositeService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CompositeService.java index 55c1aa52b95ca..ecd9de8154b31 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CompositeService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CompositeService.java @@ -23,8 +23,11 @@ import java.util.List; import org.apache.hadoop.hive.conf.HiveConf; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * CompositeService. @@ -70,7 +73,7 @@ public synchronized void start() { } super.start(); } catch (Throwable e) { - LOG.error("Error starting services " + getName(), e); + LOG.error("Error starting services {}", e, MDC.of(LogKeys.SERVICE_NAME$.MODULE$, getName())); // Note that the state of the failed service is still INITED and not // STARTED. 
Even though the last service is not started completely, still // call stop() on all services including failed service to make sure cleanup @@ -100,7 +103,7 @@ private synchronized void stop(int numOfServicesStarted) { try { service.stop(); } catch (Throwable t) { - LOG.info("Error stopping " + service.getName(), t); + LOG.info("Error stopping {}", t, MDC.of(LogKeys.SERVICE_NAME$.MODULE$, service.getName())); } } } @@ -123,7 +126,8 @@ public void run() { // Stop the Composite Service compositeService.stop(); } catch (Throwable t) { - LOG.info("Error stopping " + compositeService.getName(), t); + LOG.info("Error stopping {}", t, + MDC.of(LogKeys.SERVICE_NAME$.MODULE$, compositeService.getName())); } } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CookieSigner.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CookieSigner.java index 4b8d2cb1536cd..25e0316d5e9c3 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CookieSigner.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CookieSigner.java @@ -21,8 +21,9 @@ import java.security.NoSuchAlgorithmException; import org.apache.commons.codec.binary.Base64; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; /** * The cookie signer generates a signature based on SHA digest diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java index 434676aa8d215..d947f01681bea 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java @@ -18,8 +18,11 @@ package org.apache.hive.service; import org.apache.hadoop.hive.conf.HiveConf; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * ServiceOperations. 
@@ -129,9 +132,8 @@ public static Exception stopQuietly(Service service) { try { stop(service); } catch (Exception e) { - LOG.warn("When stopping the service " + service.getName() - + " : " + e, - e); + LOG.warn("When stopping the service {}", e, + MDC.of(LogKeys.SERVICE_NAME$.MODULE$, service.getName())); return e; } return null; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceUtils.java index 7552bda57dc0b..82ef4b9f9ce70 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceUtils.java @@ -18,7 +18,7 @@ import java.io.IOException; -import org.slf4j.Logger; +import org.apache.spark.internal.Logger; public class ServiceUtils { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index c48f4e3ec7b09..b570e88e2bc5b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -42,16 +42,19 @@ import org.apache.thrift.TProcessorFactory; import org.apache.thrift.transport.TTransportException; import org.apache.thrift.transport.TTransportFactory; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * This class helps in some aspects of authentication. It creates the proper Thrift classes for the * given configuration as well as helps with authenticating requests. */ public class HiveAuthFactory { - private static final Logger LOG = LoggerFactory.getLogger(HiveAuthFactory.class); + private static final Logger LOG = LoggerFactory.getLogger(HiveAuthFactory.class); public enum AuthTypes { NOSASL("NOSASL"), @@ -285,9 +288,9 @@ public String verifyDelegationToken(String delegationToken) throws HiveSQLExcept try { return delegationTokenManager.verifyDelegationToken(delegationToken); } catch (IOException e) { - String msg = "Error verifying delegation token " + delegationToken; - LOG.error(msg, e); - throw new HiveSQLException(msg, "08S01", e); + String msg = "Error verifying delegation token"; + LOG.error(msg + " {}", e, MDC.of(LogKeys.TOKEN$.MODULE$, delegationToken)); + throw new HiveSQLException(msg + delegationToken, "08S01", e); } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java index 08a8258db06f2..0bfe361104dea 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java @@ -39,8 +39,11 @@ import org.ietf.jgss.GSSManager; import org.ietf.jgss.GSSName; import org.ietf.jgss.Oid; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * Utility functions for HTTP mode authentication. 
@@ -109,7 +112,8 @@ public static String getUserNameFromCookieToken(String tokenStr) { Map map = splitCookieToken(tokenStr); if (!map.keySet().equals(COOKIE_ATTRIBUTES)) { - LOG.error("Invalid token with missing attributes " + tokenStr); + LOG.error("Invalid token with missing attributes {}", + MDC.of(LogKeys.TOKEN$.MODULE$, tokenStr)); return null; } return map.get(COOKIE_CLIENT_USER_NAME); @@ -129,7 +133,7 @@ private static Map splitCookieToken(String tokenStr) { String part = st.nextToken(); int separator = part.indexOf(COOKIE_KEY_VALUE_SEPARATOR); if (separator == -1) { - LOG.error("Invalid token string " + tokenStr); + LOG.error("Invalid token string {}", MDC.of(LogKeys.TOKEN$.MODULE$, tokenStr)); return null; } String key = part.substring(0, separator); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java index 1205d21be6be6..8e7d8e60c176b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java @@ -25,8 +25,9 @@ import org.apache.thrift.transport.TSaslServerTransport; import org.apache.thrift.transport.TSocket; import org.apache.thrift.transport.TTransport; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; /** * This class is responsible for setting the ipAddress for operations executed via HiveServer2. @@ -38,7 +39,7 @@ */ public class TSetIpAddressProcessor extends TCLIService.Processor { - private static final Logger LOGGER = LoggerFactory.getLogger(TSetIpAddressProcessor.class.getName()); + private static final Logger LOGGER = LoggerFactory.getLogger(TSetIpAddressProcessor.class); public TSetIpAddressProcessor(Iface iface) { super(iface); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java index caccb0c4b76f7..e612b34d7bdf7 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java @@ -49,8 +49,11 @@ import org.apache.hive.service.rpc.thrift.TRowSet; import org.apache.hive.service.rpc.thrift.TTableSchema; import org.apache.hive.service.server.HiveServer2; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * CLIService. 
@@ -99,8 +102,9 @@ public synchronized void init(HiveConf hiveConf) { String principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_PRINCIPAL); String keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_KEYTAB); if (principal.isEmpty() || keyTabFile.isEmpty()) { - LOG.info("SPNego httpUGI not created, spNegoPrincipal: " + principal + - ", ketabFile: " + keyTabFile); + LOG.info("SPNego httpUGI not created, spNegoPrincipal: {}, keytabFile: {}", + MDC.of(LogKeys.PRINCIPAL$.MODULE$, principal), + MDC.of(LogKeys.KEYTAB_FILE$.MODULE$, keyTabFile)); } else { try { this.httpUGI = HiveAuthFactory.loginFromSpnegoKeytabAndReturnUGI(hiveConf); @@ -457,7 +461,8 @@ public OperationStatus getOperationStatus(OperationHandle opHandle) LOG.trace(opHandle + ": The background operation was cancelled", e); } catch (ExecutionException e) { // The background operation thread was aborted - LOG.warn(opHandle + ": The background operation was aborted", e); + LOG.warn("{}: The background operation was aborted", e, + MDC.of(LogKeys.OPERATION_HANDLE$.MODULE$, opHandle)); } catch (InterruptedException e) { // No op, this thread was interrupted // In this case, the call might return sooner than long polling timeout @@ -551,7 +556,7 @@ public String getDelegationToken(SessionHandle sessionHandle, HiveAuthFactory au String owner, String renewer) throws HiveSQLException { String delegationToken = sessionManager.getSession(sessionHandle) .getDelegationToken(authFactory, owner, renewer); - LOG.info(sessionHandle + ": getDelegationToken()"); + LOG.info("{}: getDelegationToken()", MDC.of(LogKeys.SESSION_HANDLE$.MODULE$, sessionHandle)); return delegationToken; } @@ -559,14 +564,14 @@ public String getDelegationToken(SessionHandle sessionHandle, HiveAuthFactory au public void cancelDelegationToken(SessionHandle sessionHandle, HiveAuthFactory authFactory, String tokenStr) throws HiveSQLException { sessionManager.getSession(sessionHandle).cancelDelegationToken(authFactory, tokenStr); - LOG.info(sessionHandle + ": cancelDelegationToken()"); + LOG.info("{}: cancelDelegationToken()", MDC.of(LogKeys.SESSION_HANDLE$.MODULE$, sessionHandle)); } @Override public void renewDelegationToken(SessionHandle sessionHandle, HiveAuthFactory authFactory, String tokenStr) throws HiveSQLException { sessionManager.getSession(sessionHandle).renewDelegationToken(authFactory, tokenStr); - LOG.info(sessionHandle + ": renewDelegationToken()"); + LOG.info("{}: renewDelegationToken()", MDC.of(LogKeys.SESSION_HANDLE$.MODULE$, sessionHandle)); } @Override diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/ColumnBasedSet.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/ColumnBasedSet.java index 629d9abdac2c0..f6a269e99251d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/ColumnBasedSet.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/ColumnBasedSet.java @@ -30,8 +30,11 @@ import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.transport.TIOStreamTransport; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * ColumnBasedSet. 
@@ -68,7 +71,7 @@ public ColumnBasedSet(TRowSet tRowSet) throws TException { try { tvalue.read(protocol); } catch (TException e) { - LOG.error(e.getMessage(), e); + LOG.error("{}", e, MDC.of(LogKeys.ERROR$.MODULE$, e.getMessage())); throw new TException("Error reading column value from the row set blob", e); } columns.add(new ColumnBuffer(tvalue)); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java index 96c16beac7c4d..3876632211715 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java @@ -28,8 +28,11 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import org.apache.hadoop.hive.metastore.TableType; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * ClassicTableTypeMapping. @@ -69,7 +72,8 @@ public ClassicTableTypeMapping() { public String[] mapToHiveType(String clientTypeName) { Collection hiveTableType = clientToHiveMap.get(clientTypeName.toUpperCase()); if (hiveTableType == null) { - LOG.warn("Not supported client table type " + clientTypeName); + LOG.warn("Not supported client table type {}", + MDC.of(LogKeys.TABLE_TYPE$.MODULE$, clientTypeName)); return new String[] {clientTypeName}; } return Iterables.toArray(hiveTableType, String.class); @@ -79,7 +83,8 @@ public String[] mapToHiveType(String clientTypeName) { public String mapToClientType(String hiveTypeName) { String clientTypeName = hiveToClientMap.get(hiveTypeName); if (clientTypeName == null) { - LOG.warn("Invalid hive table type " + hiveTypeName); + LOG.warn("Invalid hive table type {}", + MDC.of(LogKeys.TABLE_TYPE$.MODULE$, hiveTypeName)); return hiveTypeName; } return clientTypeName; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/Operation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/Operation.java index ad42925207d69..135420508e21e 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/Operation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/Operation.java @@ -38,15 +38,18 @@ import org.apache.hive.service.rpc.thrift.TProtocolVersion; import org.apache.hive.service.rpc.thrift.TRowSet; import org.apache.hive.service.rpc.thrift.TTableSchema; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; public abstract class Operation { protected final HiveSession parentSession; private OperationState state = OperationState.INITIALIZED; private final OperationHandle opHandle; private HiveConf configuration; - public static final Logger LOG = LoggerFactory.getLogger(Operation.class.getName()); + public static final Logger LOG = LoggerFactory.getLogger(Operation.class); public static final FetchOrientation DEFAULT_FETCH_ORIENTATION = FetchOrientation.FETCH_NEXT; public static final long DEFAULT_FETCH_MAX_ROWS = 100; protected boolean hasResultSet; @@ -208,8 
+211,8 @@ protected void createOperationLog() { // create log file try { if (operationLogFile.exists()) { - LOG.warn("The operation log file should not exist, but it is already there: " + - operationLogFile.getAbsolutePath()); + LOG.warn("The operation log file should not exist, but it is already there: {}", + MDC.of(LogKeys.PATH$.MODULE$, operationLogFile.getAbsolutePath())); operationLogFile.delete(); } if (!operationLogFile.createNewFile()) { @@ -217,13 +220,15 @@ protected void createOperationLog() { // If it can be read/written, keep its contents and use it. if (!operationLogFile.canRead() || !operationLogFile.canWrite()) { LOG.warn("The already existed operation log file cannot be recreated, " + - "and it cannot be read or written: " + operationLogFile.getAbsolutePath()); + "and it cannot be read or written: {}", + MDC.of(LogKeys.PATH$.MODULE$, operationLogFile.getAbsolutePath())); isOperationLogEnabled = false; return; } } } catch (Exception e) { - LOG.warn("Unable to create operation log file: " + operationLogFile.getAbsolutePath(), e); + LOG.warn("Unable to create operation log file: {}", e, + MDC.of(LogKeys.PATH$.MODULE$, operationLogFile.getAbsolutePath())); isOperationLogEnabled = false; return; } @@ -232,8 +237,8 @@ protected void createOperationLog() { try { operationLog = new OperationLog(opHandle.toString(), operationLogFile, parentSession.getHiveConf()); } catch (FileNotFoundException e) { - LOG.warn("Unable to instantiate OperationLog object for operation: " + - opHandle, e); + LOG.warn("Unable to instantiate OperationLog object for operation: {}", e, + MDC.of(LogKeys.OPERATION_HANDLE$.MODULE$, opHandle)); isOperationLogEnabled = false; return; } @@ -283,8 +288,9 @@ public void run() throws HiveSQLException { protected void cleanupOperationLog() { if (isOperationLogEnabled) { if (operationLog == null) { - LOG.error("Operation [ " + opHandle.getHandleIdentifier() + " ] " - + "logging is enabled, but its OperationLog object cannot be found."); + LOG.error("Operation [ {} ] logging is enabled, " + + "but its OperationLog object cannot be found.", + MDC.of(LogKeys.OPERATION_HANDLE_IDENTIFIER$.MODULE$, opHandle.getHandleIdentifier())); } else { operationLog.close(); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java index bb68c840496ad..1498cb4907f1f 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java @@ -40,8 +40,11 @@ import org.apache.hive.service.rpc.thrift.TRowSet; import org.apache.hive.service.rpc.thrift.TTableSchema; import org.apache.logging.log4j.core.Appender; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * OperationManager. 
@@ -289,7 +292,8 @@ public List removeExpiredOperations(OperationHandle[] handles) { for (OperationHandle handle : handles) { Operation operation = removeTimedOutOperation(handle); if (operation != null) { - LOG.warn("Operation " + handle + " is timed-out and will be closed"); + LOG.warn("Operation {} is timed-out and will be closed", + MDC.of(LogKeys.OPERATION_HANDLE$.MODULE$, handle)); removed.add(operation); } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index e00d2705d4172..e073fa4713bfb 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -69,8 +69,11 @@ import org.apache.hive.service.rpc.thrift.TRowSet; import org.apache.hive.service.rpc.thrift.TTableSchema; import org.apache.hive.service.server.ThreadWithGarbageCleanup; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; import static org.apache.hadoop.hive.conf.SystemVariables.ENV_PREFIX; import static org.apache.hadoop.hive.conf.SystemVariables.HIVECONF_PREFIX; @@ -116,7 +119,7 @@ public HiveSessionImpl(TProtocolVersion protocol, String username, String passwo ShimLoader.getHadoopShims().refreshDefaultQueue(hiveConf, username); } } catch (IOException e) { - LOG.warn("Error setting scheduler queue: " + e, e); + LOG.warn("Error setting scheduler queue: ", e); } // Set an explicit session name to control the download directory name hiveConf.set("hive.session.id", @@ -146,8 +149,8 @@ public void open(Map sessionConfMap) throws HiveSQLException { sessionState.loadAuxJars(); sessionState.loadReloadableAuxJars(); } catch (IOException e) { - String msg = "Failed to load reloadable jar file path: " + e; - LOG.error(msg, e); + String msg = "Failed to load reloadable jar file path."; + LOG.error("{}", e, MDC.of(LogKeys.ERROR$.MODULE$, msg)); throw new HiveSQLException(msg, e); } // Process global init file: .hiverc @@ -197,7 +200,8 @@ private void processGlobalInitFile() { hivercFile = new File(hivercFile, SessionManager.HIVERCFILE); } if (hivercFile.isFile()) { - LOG.info("Running global init file: " + hivercFile); + LOG.info("Running global init file: {}", + MDC.of(LogKeys.GLOBAL_INIT_FILE$.MODULE$, hivercFile)); int rc = processor.processFile(hivercFile.getAbsolutePath()); if (rc != 0) { LOG.error("Failed on initializing global .hiverc file"); @@ -297,28 +301,29 @@ private static void setConf(String varname, String key, String varvalue, boolean @Override public void setOperationLogSessionDir(File operationLogRootDir) { if (!operationLogRootDir.exists()) { - LOG.warn("The operation log root directory is removed, recreating: " + - operationLogRootDir.getAbsolutePath()); + LOG.warn("The operation log root directory is removed, recreating: {}", + MDC.of(LogKeys.PATH$.MODULE$, operationLogRootDir.getAbsolutePath())); if (!operationLogRootDir.mkdirs()) { - LOG.warn("Unable to create operation log root directory: " + - operationLogRootDir.getAbsolutePath()); + LOG.warn("Unable to create operation log root directory: {}", + MDC.of(LogKeys.PATH$.MODULE$, operationLogRootDir.getAbsolutePath())); } } if (!operationLogRootDir.canWrite()) { - LOG.warn("The operation 
log root directory is not writable: " + - operationLogRootDir.getAbsolutePath()); + LOG.warn("The operation log root directory is not writable: {}", + MDC.of(LogKeys.PATH$.MODULE$, operationLogRootDir.getAbsolutePath())); } sessionLogDir = new File(operationLogRootDir, sessionHandle.getHandleIdentifier().toString()); isOperationLogEnabled = true; if (!sessionLogDir.exists()) { if (!sessionLogDir.mkdir()) { - LOG.warn("Unable to create operation log session directory: " + - sessionLogDir.getAbsolutePath()); + LOG.warn("Unable to create operation log session directory: {}", + MDC.of(LogKeys.PATH$.MODULE$, sessionLogDir.getAbsolutePath())); isOperationLogEnabled = false; } } if (isOperationLogEnabled) { - LOG.info("Operation log session directory is created: " + sessionLogDir.getAbsolutePath()); + LOG.info("Operation log session directory is created: {}", + MDC.of(LogKeys.PATH$.MODULE$, sessionLogDir.getAbsolutePath())); } } @@ -653,7 +658,8 @@ public void close() throws HiveSQLException { try { operationManager.closeOperation(opHandle); } catch (Exception e) { - LOG.warn("Exception is thrown closing operation " + opHandle, e); + LOG.warn("Exception is thrown closing operation {}", e, + MDC.of(LogKeys.OPERATION_HANDLE$.MODULE$, opHandle)); } } opHandleSet.clear(); @@ -693,13 +699,14 @@ private void cleanupPipeoutFile() { (dir, name) -> name.startsWith(sessionID) && name.endsWith(".pipeout")); if (fileAry == null) { - LOG.error("Unable to access pipeout files in " + lScratchDir); + LOG.error("Unable to access pipeout files in {}", + MDC.of(LogKeys.LOCAL_SCRATCH_DIR$.MODULE$, lScratchDir)); } else { for (File file : fileAry) { try { FileUtils.forceDelete(file); } catch (Exception e) { - LOG.error("Failed to cleanup pipeout file: " + file, e); + LOG.error("Failed to cleanup pipeout file: {}", e, MDC.of(LogKeys.PATH$.MODULE$, file)); } } } @@ -710,7 +717,8 @@ private void cleanupSessionLogDir() { try { FileUtils.forceDelete(sessionLogDir); } catch (Exception e) { - LOG.error("Failed to cleanup session log dir: " + sessionHandle, e); + LOG.error("Failed to cleanup session log dir: {}", e, + MDC.of(LogKeys.SESSION_HANDLE$.MODULE$, sessionHandle)); } } } @@ -759,7 +767,8 @@ private void closeTimedOutOperations(List operations) { try { operation.close(); } catch (Exception e) { - LOG.warn("Exception is thrown closing timed-out operation " + operation.getHandle(), e); + LOG.warn("Exception is thrown closing timed-out operation {}", e, + MDC.of(LogKeys.OPERATION_HANDLE$.MODULE$, operation.getHandle())); } } } finally { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java index fa342feacc7f4..6c282b679ca8c 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java @@ -38,8 +38,11 @@ import org.apache.hive.service.rpc.thrift.TProtocolVersion; import org.apache.hive.service.server.HiveServer2; import org.apache.hive.service.server.ThreadFactoryWithGarbageCleanup; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * SessionManager. 
@@ -84,13 +87,15 @@ public synchronized void init(HiveConf hiveConf) { private void createBackgroundOperationPool() { int poolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS); - LOG.info("HiveServer2: Background operation thread pool size: " + poolSize); + LOG.info("HiveServer2: Background operation thread pool size: {}", + MDC.of(LogKeys.THREAD_POOL_SIZE$.MODULE$, poolSize)); int poolQueueSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_WAIT_QUEUE_SIZE); - LOG.info("HiveServer2: Background operation thread wait queue size: " + poolQueueSize); + LOG.info("HiveServer2: Background operation thread wait queue size: {}", + MDC.of(LogKeys.THREAD_POOL_WAIT_QUEUE_SIZE$.MODULE$, poolQueueSize)); long keepAliveTime = HiveConf.getTimeVar( hiveConf, ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME, TimeUnit.SECONDS); - LOG.info( - "HiveServer2: Background operation thread keepalive time: " + keepAliveTime + " seconds"); + LOG.info("HiveServer2: Background operation thread keepalive time: {} ms", + MDC.of(LogKeys.THREAD_POOL_KEEPALIVE_TIME$.MODULE$, keepAliveTime * 1000L)); // Create a thread pool with #poolSize threads // Threads terminate when they are idle for more than the keepAliveTime @@ -115,26 +120,27 @@ private void initOperationLogRootDir() { isOperationLogEnabled = true; if (operationLogRootDir.exists() && !operationLogRootDir.isDirectory()) { - LOG.warn("The operation log root directory exists, but it is not a directory: " + - operationLogRootDir.getAbsolutePath()); + LOG.warn("The operation log root directory exists, but it is not a directory: {}", + MDC.of(LogKeys.PATH$.MODULE$, operationLogRootDir.getAbsolutePath())); isOperationLogEnabled = false; } if (!operationLogRootDir.exists()) { if (!operationLogRootDir.mkdirs()) { - LOG.warn("Unable to create operation log root directory: " + - operationLogRootDir.getAbsolutePath()); + LOG.warn("Unable to create operation log root directory: {}", + MDC.of(LogKeys.PATH$.MODULE$, operationLogRootDir.getAbsolutePath())); isOperationLogEnabled = false; } } if (isOperationLogEnabled) { - LOG.info("Operation log root directory is created: " + operationLogRootDir.getAbsolutePath()); + LOG.info("Operation log root directory is created: {}", + MDC.of(LogKeys.PATH$.MODULE$, operationLogRootDir.getAbsolutePath())); try { FileUtils.forceDeleteOnExit(operationLogRootDir); } catch (IOException e) { - LOG.warn("Failed to schedule cleanup HS2 operation logging root dir: " + - operationLogRootDir.getAbsolutePath(), e); + LOG.warn("Failed to schedule cleanup HS2 operation logging root dir: {}", e, + MDC.of(LogKeys.PATH$.MODULE$, operationLogRootDir.getAbsolutePath())); } } } @@ -164,12 +170,14 @@ public void run() { if (sessionTimeout > 0 && session.getLastAccessTime() + sessionTimeout <= current && (!checkOperation || session.getNoOperationTime() > sessionTimeout)) { SessionHandle handle = session.getSessionHandle(); - LOG.warn("Session " + handle + " is Timed-out (last access : " + - new Date(session.getLastAccessTime()) + ") and will be closed"); + LOG.warn("Session {} is Timed-out (last access : {}) and will be closed", + MDC.of(LogKeys.SESSION_HANDLE$.MODULE$, handle), + MDC.of(LogKeys.LAST_ACCESS_TIME$.MODULE$, new Date(session.getLastAccessTime()))); try { closeSession(handle); } catch (HiveSQLException e) { - LOG.warn("Exception is thrown closing session " + handle, e); + LOG.warn("Exception is thrown closing session {}", e, + MDC.of(LogKeys.SESSION_HANDLE$.MODULE$, handle)); } } else { session.closeExpiredOperations(); @@ 
-210,8 +218,9 @@ public synchronized void stop() { try { backgroundOperationPool.awaitTermination(timeout, TimeUnit.SECONDS); } catch (InterruptedException e) { - LOG.warn("HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT = " + timeout + - " seconds has been exceeded. RUNNING background operations will be shut down", e); + LOG.warn("HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT = {} ms has been exceeded. " + + "RUNNING background operations will be shut down", e, + MDC.of(LogKeys.TIMEOUT$.MODULE$, timeout * 1000)); } backgroundOperationPool = null; } @@ -223,8 +232,8 @@ private void cleanupLoggingRootDir() { try { FileUtils.forceDelete(operationLogRootDir); } catch (Exception e) { - LOG.warn("Failed to cleanup root dir of HS2 logging: " + operationLogRootDir - .getAbsolutePath(), e); + LOG.warn("Failed to cleanup root dir of HS2 logging: {}", e, + MDC.of(LogKeys.PATH$.MODULE$, operationLogRootDir.getAbsolutePath())); } } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java index 4b18e2950a3de..752cd54af626b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java @@ -41,8 +41,11 @@ import org.apache.thrift.server.ServerContext; import org.apache.thrift.server.TServerEventHandler; import org.apache.thrift.transport.TTransport; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * ThriftCLIService. 
@@ -50,7 +53,7 @@ */ public abstract class ThriftCLIService extends AbstractService implements TCLIService.Iface, Runnable { - public static final Logger LOG = LoggerFactory.getLogger(ThriftCLIService.class.getName()); + public static final Logger LOG = LoggerFactory.getLogger(ThriftCLIService.class); protected CLIService cliService; private static final TStatus OK_STATUS = new TStatus(TStatusCode.SUCCESS_STATUS); @@ -106,7 +109,7 @@ public void deleteContext(ServerContext serverContext, try { cliService.closeSession(sessionHandle); } catch (HiveSQLException e) { - LOG.warn("Failed to close session: " + e, e); + LOG.warn("Failed to close session: ", e); } } } @@ -236,7 +239,8 @@ private TStatus notSupportTokenErrorStatus() { @Override public TOpenSessionResp OpenSession(TOpenSessionReq req) throws TException { - LOG.info("Client protocol version: " + req.getClient_protocol()); + LOG.info("Client protocol version: {}", + MDC.of(LogKeys.PROTOCOL_VERSION$.MODULE$, req.getClient_protocol())); TOpenSessionResp resp = new TOpenSessionResp(); try { SessionHandle sessionHandle = getSessionHandle(req, resp); @@ -272,7 +276,7 @@ public TSetClientInfoResp SetClientInfo(TSetClientInfoReq req) throws TException sb.append(e.getKey()).append(" = ").append(e.getValue()); } if (sb != null) { - LOG.info("{}", sb); + LOG.info("{}", MDC.of(LogKeys.SET_CLIENT_INFO_REQUEST$.MODULE$, sb)); } } return new TSetClientInfoResp(OK_STATUS); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java index b0bede741cb19..b423038fe2b61 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java @@ -55,8 +55,11 @@ import org.ietf.jgss.GSSManager; import org.ietf.jgss.GSSName; import org.ietf.jgss.Oid; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; /** * @@ -66,7 +69,7 @@ public class ThriftHttpServlet extends TServlet { private static final long serialVersionUID = 1L; - public static final Logger LOG = LoggerFactory.getLogger(ThriftHttpServlet.class.getName()); + public static final Logger LOG = LoggerFactory.getLogger(ThriftHttpServlet.class); private final String authType; private final UserGroupInformation serviceUGI; private final UserGroupInformation httpUGI; @@ -174,7 +177,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } else { response.addCookie(hs2Cookie); } - LOG.info("Cookie added for clientUserName " + clientUserName); + LOG.info("Cookie added for clientUserName {}", + MDC.of(LogKeys.USER_NAME$.MODULE$, clientUserName)); } super.doPost(request, response); } @@ -228,7 +232,7 @@ private String getClientNameFromCookie(Cookie[] cookies) { String userName = HttpAuthUtils.getUserNameFromCookieToken(currValue); if (userName == null) { - LOG.warn("Invalid cookie token " + currValue); + LOG.warn("Invalid cookie token {}", MDC.of(LogKeys.TOKEN$.MODULE$, currValue)); continue; } //We have found a valid cookie in the client request. 
diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java index ad5ca51b9e63d..b6c9b937c5f32 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java @@ -36,9 +36,11 @@ import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService; import org.apache.hive.service.cli.thrift.ThriftCLIService; import org.apache.hive.service.cli.thrift.ThriftHttpCLIService; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; +import org.apache.spark.internal.LogKeys; +import org.apache.spark.internal.MDC; import org.apache.spark.util.ShutdownHookManager; import org.apache.spark.util.SparkExitCode; @@ -142,8 +144,8 @@ private static void startHiveServer2() throws Throwable { if (++attempts >= maxAttempts) { throw new Error("Max start attempts " + maxAttempts + " exhausted", throwable); } else { - LOG.warn("Error starting HiveServer2 on attempt " + attempts - + ", will retry in 60 seconds", throwable); + LOG.warn("Error starting HiveServer2 on attempt {}, will retry in 60 seconds", + throwable, MDC.of(LogKeys.RETRY_COUNT$.MODULE$, attempts)); try { Thread.sleep(60L * 1000L); } catch (InterruptedException e) { @@ -159,7 +161,7 @@ public static void main(String[] args) { ServerOptionsProcessor oproc = new ServerOptionsProcessor("hiveserver2"); ServerOptionsProcessorResponse oprocResponse = oproc.parse(args); - HiveStringUtils.startupShutdownMessage(HiveServer2.class, args, LOG); + HiveStringUtils.startupShutdownMessage(HiveServer2.class, args, LOG.getSlf4jLogger()); // Call the executor which will execute the appropriate command based on the parsed options oprocResponse.getServerOptionsExecutor().execute(); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadWithGarbageCleanup.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadWithGarbageCleanup.java index afaa1403bfdcd..23957e146ddf1 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadWithGarbageCleanup.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadWithGarbageCleanup.java @@ -22,8 +22,9 @@ import org.apache.hadoop.hive.metastore.HiveMetaStore; import org.apache.hadoop.hive.metastore.RawStore; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import org.apache.spark.internal.Logger; +import org.apache.spark.internal.LoggerFactory; /** * A HiveServer2 thread used to construct new server threads. 
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 7262bc22dc429..bf1c4978431b7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -33,8 +33,8 @@ import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ import org.apache.hive.service.server.HiveServer2 -import org.slf4j.Logger +import org.apache.spark.internal.Logger import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.util.SQLKeywordUtils import org.apache.spark.sql.errors.QueryExecutionErrors From 6cc3dc2ef4d2ffbff7ffc400e723b97b462e1bab Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Thu, 9 May 2024 15:35:28 +0800 Subject: [PATCH 20/65] [SPARK-48169][SPARK-48143][SQL] Revert BadRecordException optimizations ### What changes were proposed in this pull request? Revert BadRecordException optimizations for UnivocityParser, StaxXmlParser and JacksonParser ### Why are the changes needed? To reduce the blast radius - this will be implemented differently. There were two PRs by me recently: - https://github.com/apache/spark/pull/46438 - https://github.com/apache/spark/pull/46400 which introduced optimizations to speed-up control flow between UnivocityParser, StaxXmlParser and JacksonParser. However, these changes are quite unstable and may break any calling code, which relies on exception cause type, for example. Also, there may be some Spark plugins/extensions using that exception for user-facing errors ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? No Closes #46478 from vladimirg-db/vladimirg-db/revert-SPARK-48169-SPARK-48143. Authored-by: Vladimir Golubev Signed-off-by: Wenchen Fan --- .../sql/catalyst/csv/UnivocityParser.scala | 8 +++---- .../sql/catalyst/json/JacksonParser.scala | 13 ++++++----- .../catalyst/util/BadRecordException.scala | 13 +++-------- .../sql/catalyst/util/FailureSafeParser.scala | 2 +- .../sql/catalyst/xml/StaxXmlParser.scala | 23 +++++++++---------- 5 files changed, 26 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 8d06789a75126..a5158d8a22c6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -316,17 +316,17 @@ class UnivocityParser( throw BadRecordException( () => getCurrentInput, () => Array.empty, - () => QueryExecutionErrors.malformedCSVRecordError("")) + QueryExecutionErrors.malformedCSVRecordError("")) } val currentInput = getCurrentInput - var badRecordException: Option[() => Throwable] = if (tokens.length != parsedSchema.length) { + var badRecordException: Option[Throwable] = if (tokens.length != parsedSchema.length) { // If the number of tokens doesn't match the schema, we should treat it as a malformed record. // However, we still have chance to parse some of the tokens. 
It continues to parses the // tokens normally and sets null when `ArrayIndexOutOfBoundsException` occurs for missing // tokens. - Some(() => QueryExecutionErrors.malformedCSVRecordError(currentInput.toString)) + Some(QueryExecutionErrors.malformedCSVRecordError(currentInput.toString)) } else None // When the length of the returned tokens is identical to the length of the parsed schema, // we just need to: @@ -348,7 +348,7 @@ class UnivocityParser( } catch { case e: SparkUpgradeException => throw e case NonFatal(e) => - badRecordException = badRecordException.orElse(Some(() => e)) + badRecordException = badRecordException.orElse(Some(e)) // Use the corresponding DEFAULT value associated with the column, if any. row.update(i, ResolveDefaultColumns.existenceDefaultValues(requiredSchema)(i)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 848c20ee36bef..5e75ff6f6e1a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -613,7 +613,7 @@ class JacksonParser( // JSON parser currently doesn't support partial results for corrupted records. // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. - throw BadRecordException(() => recordLiteral(record), cause = () => e) + throw BadRecordException(() => recordLiteral(record), () => Array.empty, e) case e: CharConversionException if options.encoding.isEmpty => val msg = """JSON parser cannot handle a character in its input. @@ -621,17 +621,18 @@ class JacksonParser( |""".stripMargin + e.getMessage val wrappedCharException = new CharConversionException(msg) wrappedCharException.initCause(e) - throw BadRecordException(() => recordLiteral(record), cause = () => wrappedCharException) + throw BadRecordException(() => recordLiteral(record), () => Array.empty, + wrappedCharException) case PartialResultException(row, cause) => throw BadRecordException( record = () => recordLiteral(record), partialResults = () => Array(row), - cause = () => convertCauseForPartialResult(cause)) + convertCauseForPartialResult(cause)) case PartialResultArrayException(rows, cause) => throw BadRecordException( record = () => recordLiteral(record), partialResults = () => rows, - cause = () => cause) + cause) // These exceptions should never be thrown outside of JacksonParser. // They are used for the control flow in the parser. We add them here for completeness // since they also indicate a bad record. 
@@ -639,12 +640,12 @@ class JacksonParser( throw BadRecordException( record = () => recordLiteral(record), partialResults = () => Array(InternalRow(arrayData)), - cause = () => convertCauseForPartialResult(cause)) + convertCauseForPartialResult(cause)) case PartialMapDataResultException(mapData, cause) => throw BadRecordException( record = () => recordLiteral(record), partialResults = () => Array(InternalRow(mapData)), - cause = () => convertCauseForPartialResult(cause)) + convertCauseForPartialResult(cause)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala index c4fcdf40360af..65a56c1064e45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -67,23 +67,16 @@ case class PartialResultArrayException( extends Exception(cause) /** - * Exception thrown when the underlying parser meets a bad record and can't parse it. Used for - * control flow between wrapper and underlying parser without overhead of creating a full exception. + * Exception thrown when the underlying parser meet a bad record and can't parse it. * @param record a function to return the record that cause the parser to fail * @param partialResults a function that returns an row array, which is the partial results of * parsing this bad record. - * @param cause a function to return the actual exception about why the record is bad and can't be - * parsed. + * @param cause the actual exception about why the record is bad and can't be parsed. */ case class BadRecordException( @transient record: () => UTF8String, @transient partialResults: () => Array[InternalRow] = () => Array.empty[InternalRow], - @transient cause: () => Throwable) - extends Exception() { - - override def getStackTrace(): Array[StackTraceElement] = new Array[StackTraceElement](0) - override def fillInStackTrace(): Throwable = this -} + cause: Throwable) extends Exception(cause) /** * Exception thrown when the underlying parser parses a JSON array as a struct. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index b005563aa824f..10cd159c769b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -70,7 +70,7 @@ class FailureSafeParser[IN]( case DropMalformedMode => Iterator.empty case FailFastMode => - e.cause() match { + e.getCause match { case _: JsonArraysAsStructsException => // SPARK-42298 we recreate the exception here to make sure the error message // have the record content. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 2b237ab5db643..ab671e56a21e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -148,27 +148,26 @@ class StaxXmlParser( // XML parser currently doesn't support partial results for corrupted records. // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. 
- throw BadRecordException(() => xmlRecord, cause = () => e) + throw BadRecordException(() => xmlRecord, () => Array.empty, e) case e: CharConversionException if options.charset.isEmpty => - throw BadRecordException(() => xmlRecord, cause = () => { - val msg = - """XML parser cannot handle a character in its input. - |Specifying encoding as an input option explicitly might help to resolve the issue. - |""".stripMargin + e.getMessage - val wrappedCharException = new CharConversionException(msg) - wrappedCharException.initCause(e) - wrappedCharException - }) + val msg = + """XML parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. + |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => xmlRecord, () => Array.empty, + wrappedCharException) case PartialResultException(row, cause) => throw BadRecordException( record = () => xmlRecord, partialResults = () => Array(row), - () => cause) + cause) case PartialResultArrayException(rows, cause) => throw BadRecordException( record = () => xmlRecord, partialResults = () => rows, - () => cause) + cause) } } From a4ab82b8f340afa89b8865c92695d1fb102f974b Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 9 May 2024 16:01:07 +0800 Subject: [PATCH 21/65] [SPARK-48186][SQL] Add support for AbstractMapType ### What changes were proposed in this pull request? Addition of an abstract MapType (similar to abstract ArrayType in sql internal types) which accepts `StringTypeCollated` as `keyType` & `valueType`. Apart from extending this interface for all Spark functions, this PR also introduces collation awareness for json expression: schema_of_json. ### Why are the changes needed? This is needed in order to enable collation support for functions that use collated maps. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for json function: schema_of_json. ### How was this patch tested? E2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46458 from uros-db/abstract-map. Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../sql/internal/types/AbstractMapType.scala | 43 +++++++++++++++++++ .../sql/catalyst/expressions/ExprUtils.scala | 9 ++-- .../expressions/jsonExpressions.scala | 5 ++- .../sql/CollationSQLExpressionsSuite.scala | 34 +++++++++++++++ 4 files changed, 85 insertions(+), 6 deletions(-) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala new file mode 100644 index 0000000000000..62f422f6f80a7 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal.types + +import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType} + + +/** + * Use AbstractMapType(AbstractDataType, AbstractDataType) + * for defining expected types for expression parameters. + */ +case class AbstractMapType( + keyType: AbstractDataType, + valueType: AbstractDataType + ) extends AbstractDataType { + + override private[sql] def defaultConcreteType: DataType = + MapType(keyType.defaultConcreteType, valueType.defaultConcreteType, valueContainsNull = true) + + override private[sql] def acceptsType(other: DataType): Boolean = { + other.isInstanceOf[MapType] && + keyType.acceptsType(other.asInstanceOf[MapType].keyType) && + valueType.acceptsType(other.asInstanceOf[MapType].valueType) + } + + override private[spark] def simpleString: String = + s"map<${keyType.simpleString}, ${valueType.simpleString}>" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 258bc0ed8fe73..fde2093460876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation} import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String @@ -57,7 +58,7 @@ object ExprUtils extends QueryErrorsBase { def convertToMapData(exp: Expression): Map[String, String] = exp match { case m: CreateMap - if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => + if AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation).acceptsType(m.dataType) => val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => key.toString -> value.toString @@ -77,7 +78,7 @@ object ExprUtils extends QueryErrorsBase { columnNameOfCorruptRecord: String): Unit = { schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => val f = schema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { + if (!f.dataType.isInstanceOf[StringType] || !f.nullable) { throw QueryCompilationErrors.invalidFieldTypeForCorruptRecordError() } } @@ -110,7 +111,7 @@ object ExprUtils extends QueryErrorsBase { */ def checkJsonSchema(schema: DataType): TypeCheckResult = { val isInvalid = schema.existsRecursively { - case MapType(keyType, _, _) if keyType != StringType => true + case MapType(keyType, _, _) if !keyType.isInstanceOf[StringType] => true case _ => false } if (isInvalid) { @@ -133,7 +134,7 @@ object 
ExprUtils extends QueryErrorsBase { def checkXmlSchema(schema: DataType): TypeCheckResult = { val isInvalid = schema.existsRecursively { // XML field names must be StringType - case MapType(keyType, _, _) if keyType != StringType => true + case MapType(keyType, _, _) if !keyType.isInstanceOf[StringType] => true case _ => false } if (isInvalid) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8258bb389e2da..7005d663a3f96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -875,7 +875,7 @@ case class SchemaOfJson( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false @@ -921,7 +921,8 @@ case class SchemaOfJson( .map(ArrayType(_, containsNull = at.containsNull)) .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) case other: DataType => - jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse(StringType) + jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse( + SQLConf.get.defaultStringType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 19f34ec15aa07..530a77616c7c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -658,6 +658,40 @@ class CollationSQLExpressionsSuite }) } + test("Support SchemaOfJson json expression with collation") { + case class SchemaOfJsonTestCase( + input: String, + collationName: String, + result: Row + ) + + val testCases = Seq( + SchemaOfJsonTestCase("'[{\"col\":0}]'", + "UTF8_BINARY", Row("ARRAY>")), + SchemaOfJsonTestCase("'[{\"col\":01}]', map('allowNumericLeadingZeros', 'true')", + "UTF8_BINARY_LCASE", Row("ARRAY>")), + SchemaOfJsonTestCase("'[]'", + "UNICODE", Row("ARRAY")), + SchemaOfJsonTestCase("''", + "UNICODE_CI", Row("STRING")) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT schema_of_json(${t.input}) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, t.result) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + test("Support StringToMap expression with collation") { // Supported collations case class StringToMapTestCase[R](t: String, p: String, k: String, c: String, result: R) From 91da4ac25148771b3656bc23b85fd2459ea0350a Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 9 May 2024 16:05:17 +0800 Subject: [PATCH 22/65] [SPARK-47354][SQL] Add collation support for variant expressions ### What changes were proposed in this pull request? Introduce collation awareness for variant expressions: parse_json, try_parse_json, is_variant_null, variant_get, try_variant_get, variant_explode, schema_of_variant, schema_of_variant_agg. ### Why are the changes needed? Add collation support for variant expressions in Spark. 
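A minimal usage sketch (assuming the `collate` SQL function from the broader collation framework is available; the JSON literal and collation name below are only illustrative):

```sql
-- parse_json now accepts a collated string input, and schema_of_variant
-- returns the session's default (possibly collated) string type.
SELECT schema_of_variant(parse_json(collate('{"a": 1}', 'UNICODE_CI')));
```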
### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for format functions: parse_json, try_parse_json, is_variant_null, variant_get, try_variant_get, variant_explode, schema_of_variant, schema_of_variant_agg. ### How was this patch tested? E2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46424 from uros-db/variant-expressions. Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../function_is_variant_null.explain | 2 +- .../function_parse_json.explain | 2 +- .../function_schema_of_variant.explain | 2 +- .../function_schema_of_variant_agg.explain | 2 +- .../function_try_parse_json.explain | 2 +- .../function_try_variant_get.explain | 2 +- .../function_variant_get.explain | 2 +- .../variant/variantExpressions.scala | 23 +- .../sql/CollationSQLExpressionsSuite.scala | 293 +++++++++++++++++- 9 files changed, 312 insertions(+), 18 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_is_variant_null.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_is_variant_null.explain index 53ba167fca656..3c0b4fd87d9d2 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_is_variant_null.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_is_variant_null.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, BooleanType, isVariantNull, staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringType, BooleanType, true, false, true), VariantType, false, false, true) AS is_variant_null(parse_json(g))#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, BooleanType, isVariantNull, staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringTypeAnyCollation, BooleanType, true, false, true), VariantType, false, false, true) AS is_variant_null(parse_json(g))#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_parse_json.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_parse_json.explain index b844d19c85ac1..9ba74d04b02a4 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_parse_json.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_parse_json.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringType, BooleanType, true, false, true) AS parse_json(g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringTypeAnyCollation, BooleanType, true, false, true) AS parse_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant.explain 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant.explain index 62f8e7f3e6fea..d61db9f5394c5 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant$, StringType, schemaOfVariant, staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringType, BooleanType, true, false, true), VariantType, true, false, true) AS schema_of_variant(parse_json(g))#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant$, StringType, schemaOfVariant, staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringTypeAnyCollation, BooleanType, true, false, true), VariantType, true, false, true) AS schema_of_variant(parse_json(g))#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant_agg.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant_agg.explain index d4f9e2c66d99c..36f8920ce10cf 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant_agg.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant_agg.explain @@ -1,2 +1,2 @@ -Aggregate [schema_of_variant_agg(staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringType, BooleanType, true, false, true), 0, 0) AS schema_of_variant_agg(parse_json(g))#0] +Aggregate [schema_of_variant_agg(staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringTypeAnyCollation, BooleanType, true, false, true), 0, 0) AS schema_of_variant_agg(parse_json(g))#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_json.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_json.explain index 5c6b21a3ad46b..fda72dae1a747 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_json.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_json.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, false, StringType, BooleanType, true, true, true) AS try_parse_json(g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, false, StringTypeAnyCollation, BooleanType, true, true, true) AS try_parse_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_variant_get.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_variant_get.explain index 
748465142bde9..143bd113fd87f 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_variant_get.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_variant_get.explain @@ -1,2 +1,2 @@ -Project [try_variant_get(staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringType, BooleanType, true, false, true), $, IntegerType, false, Some(America/Los_Angeles)) AS try_variant_get(parse_json(g), $)#0] +Project [try_variant_get(staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringTypeAnyCollation, BooleanType, true, false, true), $, IntegerType, false, Some(America/Los_Angeles)) AS try_variant_get(parse_json(g), $)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_variant_get.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_variant_get.explain index 3503ee178ca71..f3af6fa9cf209 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_variant_get.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_variant_get.explain @@ -1,2 +1,2 @@ -Project [variant_get(staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringType, BooleanType, true, false, true), $, IntegerType, true, Some(America/Los_Angeles)) AS variant_get(parse_json(g), $)#0] +Project [variant_get(staticinvoke(class org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils$, VariantType, parseJson, g#0, true, StringTypeAnyCollation, BooleanType, true, false, true), $, IntegerType, true, Some(America/Los_Angeles)) AS variant_get(parse_json(g), $)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 5026d8e49ef16..2b8beacc45d36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -37,6 +37,8 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.types.variant._ import org.apache.spark.types.variant.VariantUtil.Type @@ -61,7 +63,7 @@ case class ParseJson(child: Expression, failOnError: Boolean = true) inputTypes :+ BooleanType, returnNullable = !failOnError) - override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil override def dataType: DataType = VariantType @@ -199,7 +201,7 @@ case class VariantGet( final override def nodePatternsInternal(): Seq[TreePattern] = 
Seq(VARIANT_GET) - override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringTypeAnyCollation) override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get" @@ -260,7 +262,7 @@ case object VariantGet { VariantType => true case ArrayType(elementType, _) => checkDataType(elementType) - case MapType(StringType, valueType, _) => checkDataType(valueType) + case MapType(_: StringType, valueType, _) => checkDataType(valueType) case StructType(fields) => fields.forall(f => checkDataType(f.dataType)) case _ => false } @@ -334,7 +336,8 @@ case object VariantGet { } case Type.BOOLEAN => Literal(v.getBoolean, BooleanType) case Type.LONG => Literal(v.getLong, LongType) - case Type.STRING => Literal(UTF8String.fromString(v.getString), StringType) + case Type.STRING => Literal(UTF8String.fromString(v.getString), + SQLConf.get.defaultStringType) case Type.DOUBLE => Literal(v.getDouble, DoubleType) case Type.DECIMAL => val d = Decimal(v.getDecimal) @@ -387,7 +390,7 @@ case object VariantGet { } else { invalidCast() } - case MapType(StringType, valueType, _) => + case MapType(_: StringType, valueType, _) => if (variantType == Type.OBJECT) { val size = v.objectSize() val keyArray = new Array[Any](size) @@ -568,7 +571,7 @@ case class VariantExplode(child: Expression) extends UnaryExpression with Genera override def elementSchema: StructType = { new StructType() .add("pos", IntegerType, nullable = false) - .add("key", StringType, nullable = true) + .add("key", SQLConf.get.defaultStringType, nullable = true) .add("value", VariantType, nullable = false) } } @@ -625,7 +628,7 @@ case class SchemaOfVariant(child: Expression) with ExpectsInputTypes { override lazy val replacement: Expression = StaticInvoke( SchemaOfVariant.getClass, - StringType, + SQLConf.get.defaultStringType, "schemaOfVariant", Seq(child), inputTypes, @@ -633,7 +636,7 @@ case class SchemaOfVariant(child: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(VariantType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "schema_of_variant" @@ -676,7 +679,7 @@ object SchemaOfVariant { case Type.NULL => NullType case Type.BOOLEAN => BooleanType case Type.LONG => LongType - case Type.STRING => StringType + case Type.STRING => SQLConf.get.defaultStringType case Type.DOUBLE => DoubleType case Type.DECIMAL => val d = v.getDecimal @@ -722,7 +725,7 @@ case class SchemaOfVariantAgg( override def inputTypes: Seq[AbstractDataType] = Seq(VariantType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 530a77616c7c2..b5f1dc768bca0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.collection.immutable.Seq -import org.apache.spark.SparkIllegalArgumentException +import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.SharedSparkSession import 
org.apache.spark.sql.types._ @@ -757,6 +757,297 @@ class CollationSQLExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("Support ParseJson & TryParseJson variant expressions with collation") { + case class ParseJsonTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + ParseJsonTestCase("{\"a\":1,\"b\":2}", "UTF8_BINARY", "{\"a\":1,\"b\":2}"), + ParseJsonTestCase("{\"A\":3,\"B\":4}", "UTF8_BINARY_LCASE", "{\"A\":3,\"B\":4}"), + ParseJsonTestCase("{\"c\":5,\"d\":6}", "UNICODE", "{\"c\":5,\"d\":6}"), + ParseJsonTestCase("{\"C\":7,\"D\":8}", "UNICODE_CI", "{\"C\":7,\"D\":8}") + ) + + // Supported collations (ParseJson) + testCases.foreach(t => { + val query = + s""" + |SELECT parse_json('${t.input}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === "[" + t.result + "]") // can't use checkAnswer for Variant + assert(testQuery.schema.fields.head.dataType.sameType(VariantType)) + } + }) + + // Supported collations (TryParseJson) + testCases.foreach(t => { + val query = + s""" + |SELECT try_parse_json('${t.input}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === "[" + t.result + "]") // can't use checkAnswer for Variant + assert(testQuery.schema.fields.head.dataType.sameType(VariantType)) + } + }) + } + + test("Handle invalid JSON for ParseJson variant expression with collation") { + // parse_json should throw an exception when the string is not valid JSON value + val json = "{\"a\":1," + val query = s"SELECT parse_json('$json');" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + val e = intercept[SparkException] { + val testQuery = sql(query) + testQuery.collect() + } + assert(e.getErrorClass === "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") + } + } + + test("Handle invalid JSON for TryParseJson variant expression with collation") { + // try_parse_json shouldn't throw an exception when the string is not valid JSON value + val json = "{\"a\":1,]" + val query = s"SELECT try_parse_json('$json');" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === s"[null]") + } + } + + test("Support IsVariantNull variant expressions with collation") { + case class IsVariantNullTestCase( + input: String, + collationName: String, + result: Boolean + ) + + val testCases = Seq( + IsVariantNullTestCase("'null'", "UTF8_BINARY", result = true), + IsVariantNullTestCase("'\"null\"'", "UTF8_BINARY_LCASE", result = false), + IsVariantNullTestCase("'13'", "UNICODE", result = false), + IsVariantNullTestCase("null", "UNICODE_CI", result = false) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT is_variant_null(parse_json(${t.input})) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + } + }) + } + + test("Support VariantGet & TryVariantGet variant expressions with collation") { + case class VariantGetTestCase( + input: String, + path: String, + variantType: String, + 
collationName: String, + result: Any, + resultType: DataType + ) + + val testCases = Seq( + VariantGetTestCase("{\"a\": 1}", "$.a", "int", "UTF8_BINARY", 1, IntegerType), + VariantGetTestCase("{\"a\": 1}", "$.b", "int", "UTF8_BINARY_LCASE", null, IntegerType), + VariantGetTestCase("[1, \"2\"]", "$[1]", "string", "UNICODE", "2", StringType("UNICODE")), + VariantGetTestCase("[1, \"2\"]", "$[2]", "string", "UNICODE_CI", null, + StringType("UNICODE_CI")) + ) + + // Supported collations (VariantGet) + testCases.foreach(t => { + val query = + s""" + |SELECT variant_get(parse_json('${t.input}'), '${t.path}', '${t.variantType}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === "[" + t.result + "]") // can't use checkAnswer for Variant + assert(testQuery.schema.fields.head.dataType.sameType(t.resultType)) + } + }) + + // Supported collations (TryVariantGet) + testCases.foreach(t => { + val query = + s""" + |SELECT try_variant_get(parse_json('${t.input}'), '${t.path}', '${t.variantType}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === "[" + t.result + "]") // can't use checkAnswer for Variant + assert(testQuery.schema.fields.head.dataType.sameType(t.resultType)) + } + }) + } + + test("Handle invalid JSON for VariantGet variant expression with collation") { + // variant_get should throw an exception if the cast fails + val json = "[1, \"Spark\"]" + val query = s"SELECT variant_get(parse_json('$json'), '$$[1]', 'int');" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + val e = intercept[SparkRuntimeException] { + val testQuery = sql(query) + testQuery.collect() + } + assert(e.getErrorClass === "INVALID_VARIANT_CAST") + } + } + + test("Handle invalid JSON for TryVariantGet variant expression with collation") { + // try_variant_get shouldn't throw an exception if the cast fails + val json = "[1, \"Spark\"]" + val query = s"SELECT try_variant_get(parse_json('$json'), '$$[1]', 'int');" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === s"[null]") + } + } + + test("Support VariantExplode variant expressions with collation") { + case class VariantExplodeTestCase( + input: String, + collationName: String, + result: String, + resultType: Seq[StructField] + ) + + val testCases = Seq( + VariantExplodeTestCase("[\"hello\", \"world\"]", "UTF8_BINARY", + Row(0, "null", "\"hello\"").toString() + Row(1, "null", "\"world\"").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UTF8_BINARY")), + StructField("value", VariantType, nullable = false) + ) + ), + VariantExplodeTestCase("[\"Spark\", \"SQL\"]", "UTF8_BINARY_LCASE", + Row(0, "null", "\"Spark\"").toString() + Row(1, "null", "\"SQL\"").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UTF8_BINARY_LCASE")), + StructField("value", VariantType, nullable = false) + ) + ), + VariantExplodeTestCase("{\"a\": true, \"b\": 3.14}", "UNICODE", + Row(0, "a", "true").toString() + Row(1, "b", "3.14").toString(), + Seq[StructField]( + 
StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UNICODE")), + StructField("value", VariantType, nullable = false) + ) + ), + VariantExplodeTestCase("{\"A\": 9.99, \"B\": false}", "UNICODE_CI", + Row(0, "A", "9.99").toString() + Row(1, "B", "false").toString(), + Seq[StructField]( + StructField("pos", IntegerType, nullable = false), + StructField("key", StringType("UNICODE_CI")), + StructField("value", VariantType, nullable = false) + ) + ) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT * from variant_explode(parse_json('${t.input}')) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + val testResult = testQuery.collect().map(_.toString()).mkString("") + assert(testResult === t.result) // can't use checkAnswer for Variant + assert(testQuery.schema.fields.sameElements(t.resultType)) + } + }) + } + + test("Support SchemaOfVariant variant expressions with collation") { + case class SchemaOfVariantTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + SchemaOfVariantTestCase("null", "UTF8_BINARY", "VOID"), + SchemaOfVariantTestCase("[]", "UTF8_BINARY_LCASE", "ARRAY"), + SchemaOfVariantTestCase("[{\"a\":true,\"b\":0}]", "UNICODE", + "ARRAY>"), + SchemaOfVariantTestCase("[{\"A\":\"x\",\"B\":-1.00}]", "UNICODE_CI", + "ARRAY>") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT schema_of_variant(parse_json('${t.input}')) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + assert(testQuery.schema.fields.head.dataType.sameType(StringType(t.collationName))) + } + }) + } + + test("Support SchemaOfVariantAgg variant expressions with collation") { + case class SchemaOfVariantAggTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + SchemaOfVariantAggTestCase("('1'), ('2'), ('3')", "UTF8_BINARY", "BIGINT"), + SchemaOfVariantAggTestCase("('true'), ('false'), ('true')", "UTF8_BINARY_LCASE", "BOOLEAN"), + SchemaOfVariantAggTestCase("('{\"a\": 1}'), ('{\"b\": true}'), ('{\"c\": 1.23}')", + "UNICODE", "STRUCT"), + SchemaOfVariantAggTestCase("('{\"A\": \"x\"}'), ('{\"B\": 9.99}'), ('{\"C\": 0}')", + "UNICODE_CI", "STRUCT") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ${t.input} AS tab(j) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + assert(testQuery.schema.fields.head.dataType.sameType(StringType(t.collationName))) + } + }) + } + // TODO: Add more tests for other SQL expressions } From 045ec6a166c8d2bdf73585fc4160c136e5f2888a Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Thu, 9 May 2024 17:10:01 +0900 Subject: [PATCH 23/65] [SPARK-48208][SS] Skip providing memory usage metrics from RocksDB if bounded memory usage is enabled ### What changes were proposed in this pull request? Skip providing memory usage metrics from RocksDB if bounded memory usage is enabled ### Why are the changes needed? Without this, we are providing memory usage that is the max usage per node at a partition level. 
For eg - if we report this ``` "allRemovalsTimeMs" : 93, "commitTimeMs" : 32240, "memoryUsedBytes" : 15956211724278, "numRowsDroppedByWatermark" : 0, "numShufflePartitions" : 200, "numStateStoreInstances" : 200, ``` We have 200 partitions in this case. So the memory usage per partition / state store would be ~78GB. However, this node has 256GB memory total and we have 2 such nodes. We have configured our cluster to use 30% of available memory on each node for RocksDB which is ~77GB. So the memory being reported here is actually per node rather than per partition which could be confusing for users. ### Does this PR introduce _any_ user-facing change? No - only a metrics reporting change ### How was this patch tested? Added unit tests ``` [info] Run completed in 10 seconds, 878 milliseconds. [info] Total number of tests run: 24 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 24, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #46491 from anishshri-db/task/SPARK-48208. Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../spark/sql/execution/streaming/state/RocksDB.scala | 11 ++++++++++- .../sql/execution/streaming/state/RocksDBSuite.scala | 11 +++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index caecf817c12f4..1516951922812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -777,10 +777,19 @@ class RocksDB( .keys.filter(checkInternalColumnFamilies(_)).size val numExternalColFamilies = colFamilyNameToHandleMap.keys.size - numInternalColFamilies + // if bounded memory usage is enabled, we share the block cache across all state providers + // running on the same node and account the usage to this single cache. In this case, its not + // possible to provide partition level or query level memory usage. 
+ val memoryUsage = if (conf.boundedMemoryUsage) { + 0L + } else { + readerMemUsage + memTableMemUsage + blockCacheUsage + } + RocksDBMetrics( numKeysOnLoadedVersion, numKeysOnWritingVersion, - readerMemUsage + memTableMemUsage + blockCacheUsage, + memoryUsage, pinnedBlocksMemUsage, totalSSTFilesBytes, nativeOpsLatencyMicros, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index ab2afa1b8a617..6086fd43846f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -1699,6 +1699,11 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared db.load(0) db.put("a", "1") db.commit() + if (boundedMemoryUsage == "true") { + assert(db.metricsOpt.get.totalMemUsageBytes === 0) + } else { + assert(db.metricsOpt.get.totalMemUsageBytes > 0) + } db.getWriteBufferManagerAndCache() } @@ -1709,6 +1714,11 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared db.load(0) db.put("a", "1") db.commit() + if (boundedMemoryUsage == "true") { + assert(db.metricsOpt.get.totalMemUsageBytes === 0) + } else { + assert(db.metricsOpt.get.totalMemUsageBytes > 0) + } db.getWriteBufferManagerAndCache() } @@ -1758,6 +1768,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared db.remove("a") db.put("c", "3") db.commit() + assert(db.metricsOpt.get.totalMemUsageBytes === 0) } } finally { RocksDBMemoryManager.resetWriteBufferManagerAndCache From 34ee0d8414b2f919a5f40e4e4b1ab4cfd033b696 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 9 May 2024 16:11:29 +0800 Subject: [PATCH 24/65] [SPARK-47421][SQL] Add collation support for URL expressions ### What changes were proposed in this pull request? Introduce collation awareness for URL expressions: url_encode, url_decode, parse_url. ### Why are the changes needed? Add collation support for URL expressions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for URL functions: url_encode, url_decode, parse_url. ### How was this patch tested? E2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46460 from uros-db/url-expressions. 
Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../function_url_decode.explain | 2 +- .../function_url_encode.explain | 2 +- .../catalyst/expressions/urlExpressions.scala | 19 ++-- .../sql/CollationSQLExpressionsSuite.scala | 103 ++++++++++++++++++ 4 files changed, 115 insertions(+), 11 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain index d612190396d2b..ee4936fec5374 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, decode, g#0, UTF-8, StringType, StringType, true, true, true) AS url_decode(g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, decode, g#0, UTF-8, StringTypeAnyCollation, StringTypeAnyCollation, true, true, true) AS url_decode(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain index bd2c63e19c609..45c55f4f87375 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, encode, g#0, UTF-8, StringType, StringType, true, true, true) AS url_encode(g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, encode, g#0, UTF-8, StringTypeAnyCollation, StringTypeAnyCollation, true, true, true) AS url_encode(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 47b37a5edeba8..ef8f2ea96eb0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} +import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.types.{AbstractDataType, DataType} import org.apache.spark.unsafe.types.UTF8String // scalastyle:off line.size.limit @@ -54,16 +55,16 @@ case class UrlEncode(child: Expression) override def replacement: Expression = StaticInvoke( UrlCodec.getClass, - StringType, + SQLConf.get.defaultStringType, "encode", Seq(child, Literal("UTF-8")), - Seq(StringType, StringType)) + Seq(StringTypeAnyCollation, StringTypeAnyCollation)) override protected def withNewChildInternal(newChild: Expression): 
Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override def prettyName: String = "url_encode" } @@ -91,16 +92,16 @@ case class UrlDecode(child: Expression) override def replacement: Expression = StaticInvoke( UrlCodec.getClass, - StringType, + SQLConf.get.defaultStringType, "decode", Seq(child, Literal("UTF-8")), - Seq(StringType, StringType)) + Seq(StringTypeAnyCollation, StringTypeAnyCollation)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override def prettyName: String = "url_decode" } @@ -154,8 +155,8 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) override def nullable: Boolean = true - override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) - override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeAnyCollation) + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "parse_url" // If the url is a constant, cache the URL object so that we don't need to convert url diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index b5f1dc768bca0..2b6390151bb9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -208,6 +208,109 @@ class CollationSQLExpressionsSuite }) } + test("Support UrlEncode hash expression with collation") { + case class UrlEncodeTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + UrlEncodeTestCase("https://spark.apache.org", "UTF8_BINARY", + "https%3A%2F%2Fspark.apache.org"), + UrlEncodeTestCase("https://spark.apache.org", "UTF8_BINARY_LCASE", + "https%3A%2F%2Fspark.apache.org"), + UrlEncodeTestCase("https://spark.apache.org", "UNICODE", + "https%3A%2F%2Fspark.apache.org"), + UrlEncodeTestCase("https://spark.apache.org", "UNICODE_CI", + "https%3A%2F%2Fspark.apache.org") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select url_encode('${t.input}') + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support UrlDecode hash expression with collation") { + case class UrlDecodeTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UTF8_BINARY", + "https://spark.apache.org"), + UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UTF8_BINARY_LCASE", + "https://spark.apache.org"), + UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UNICODE", + "https://spark.apache.org"), + UrlDecodeTestCase("https%3A%2F%2Fspark.apache.org", "UNICODE_CI", + "https://spark.apache.org") + ) + + // Supported collations + 
testCases.foreach(t => { + val query = + s""" + |select url_decode('${t.input}') + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support ParseUrl hash expression with collation") { + case class ParseUrlTestCase( + input: String, + collationName: String, + path: String, + result: String + ) + + val testCases = Seq( + ParseUrlTestCase("http://spark.apache.org/path?query=1", "UTF8_BINARY", "HOST", + "spark.apache.org"), + ParseUrlTestCase("http://spark.apache.org/path?query=2", "UTF8_BINARY_LCASE", "PATH", + "/path"), + ParseUrlTestCase("http://spark.apache.org/path?query=3", "UNICODE", "QUERY", + "query=3"), + ParseUrlTestCase("http://spark.apache.org/path?query=4", "UNICODE_CI", "PROTOCOL", + "http") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select parse_url('${t.input}', '${t.path}') + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + test("Conv expression with collation") { // Supported collations case class ConvTestCase( From 027327d94b3413ffb228ac482d51e75856c88d02 Mon Sep 17 00:00:00 2001 From: Niranjan Jayakar Date: Thu, 9 May 2024 17:23:31 +0900 Subject: [PATCH 25/65] [SPARK-47986][CONNECT][PYTHON] Unable to create a new session when the default session is closed by the server ### What changes were proposed in this pull request? This is a follow-up to a previous improvement - 7d04d0f0. In some cases, particularly when running older versions of the Spark cluster (3.5), the error actually manifests as a mismatch in the observed server-side session id between calls. With this fix, we also capture this case and ensure that this case is also handled. Further, we improve the implementation of `getActiveSession()` and introduce a similar `getDefaultSession()` that accounts for stopped sessions. This ensures that all places where default or active session is used, stopped sessions are considered neither default nor active. ### Why are the changes needed? Explained above. ### Does this PR introduce _any_ user-facing change? Previously, when client encounters a session mismatch, a user cannot create a new session. With this change, a user can call `getOrCreate()` on the SparkSession builder and create a new session. ### How was this patch tested? Attached unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46435 from nija-at/session-expires-part2. 
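As an illustration of the behaviour change described above, here is a minimal sketch of the recovery flow (the `sc://localhost` endpoint and the surrounding error handling are illustrative, not part of this patch):

```python
# Minimal sketch of the recovery flow; assumes a Spark Connect server is
# reachable at sc://localhost (illustrative endpoint).
from pyspark.sql import SparkSession
from pyspark.errors import SparkConnectException

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
spark.range(3).collect()  # first request synchronizes the server-side session id

try:
    # If the server closes or swaps the session, the next request fails and the
    # client either marks itself closed or detects a server-side session id mismatch.
    spark.range(3).collect()
except SparkConnectException:
    # With this change, the stopped session is no longer treated as the default or
    # active session, so getOrCreate() builds a fresh session instead of failing.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    spark.range(3).collect()
```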
Lead-authored-by: Niranjan Jayakar Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/core.py | 1 + python/pyspark/sql/connect/session.py | 20 +++++++++---- .../sql/tests/connect/test_connect_session.py | 16 ++++++++++- .../pyspark/sql/tests/connect/test_session.py | 28 +++++++++++++++++++ 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index badd9a33397ea..5e3462c2d0c1c 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1829,6 +1829,7 @@ def _verify_response_integrity( response.server_side_session_id and response.server_side_session_id != self._server_session_id ): + self._closed = True raise PySparkAssertionError( "Received incorrect server side session identifier for request. " "Please create a new Spark Session to reconnect. (" diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index b688ca022c8c9..bec3c5b579a0c 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -237,9 +237,9 @@ def create(self) -> "SparkSession": def getOrCreate(self) -> "SparkSession": with SparkSession._lock: session = SparkSession.getActiveSession() - if session is None or session.is_stopped: - session = SparkSession._default_session - if session is None or session.is_stopped: + if session is None: + session = SparkSession._get_default_session() + if session is None: session = self.create() self._apply_options(session) return session @@ -285,9 +285,19 @@ def _set_default_and_active_session(cls, session: "SparkSession") -> None: if getattr(cls._active_session, "session", None) is None: cls._active_session.session = session + @classmethod + def _get_default_session(cls) -> Optional["SparkSession"]: + s = cls._default_session + if s is not None and not s.is_stopped: + return s + return None + @classmethod def getActiveSession(cls) -> Optional["SparkSession"]: - return getattr(cls._active_session, "session", None) + s = getattr(cls._active_session, "session", None) + if s is not None and not s.is_stopped: + return s + return None @classmethod def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession": @@ -315,7 +325,7 @@ def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession": def active(cls) -> "SparkSession": session = cls.getActiveSession() if session is None: - session = cls._default_session + session = cls._get_default_session() if session is None: raise PySparkRuntimeError( error_class="NO_ACTIVE_OR_DEFAULT_SESSION", diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py index c5ce697a95612..1dd5cde0dff50 100644 --- a/python/pyspark/sql/tests/connect/test_connect_session.py +++ b/python/pyspark/sql/tests/connect/test_connect_session.py @@ -242,7 +242,7 @@ def toChannel(self): session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create() session.sql("select 1 + 1") - def test_reset_when_server_session_changes(self): + def test_reset_when_server_and_client_sessionids_mismatch(self): session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate() # run a simple query so the session id is synchronized. 
session.range(3).collect() @@ -256,6 +256,20 @@ def test_reset_when_server_session_changes(self): session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate() session.range(3).collect() + def test_reset_when_server_session_id_mismatch(self): + session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate() + # run a simple query so the session id is synchronized. + session.range(3).collect() + + # trigger a mismatch + session._client._server_session_id = str(uuid.uuid4()) + with self.assertRaises(SparkConnectException): + session.range(3).collect() + + # assert that getOrCreate() generates a new session + session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate() + session.range(3).collect() + class SparkConnectSessionWithOptionsTest(unittest.TestCase): def setUp(self) -> None: diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py index 5184b9f061712..820f54b833275 100644 --- a/python/pyspark/sql/tests/connect/test_session.py +++ b/python/pyspark/sql/tests/connect/test_session.py @@ -77,6 +77,34 @@ def test_session_create_sets_active_session(self): self.assertIs(session, session2) session.stop() + def test_active_session_expires_when_client_closes(self): + s1 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + s2 = RemoteSparkSession.getActiveSession() + + self.assertIs(s1, s2) + + # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator + s1._client._closed = True + + self.assertIsNone(RemoteSparkSession.getActiveSession()) + s3 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + + self.assertIsNot(s1, s3) + + def test_default_session_expires_when_client_closes(self): + s1 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + s2 = RemoteSparkSession.getDefaultSession() + + self.assertIs(s1, s2) + + # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator + s1._client._closed = True + + self.assertIsNone(RemoteSparkSession.getDefaultSession()) + s3 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + + self.assertIsNot(s1, s3) + class JobCancellationTests(ReusedConnectTestCase): def test_tags(self): From ecca1bf6453e5e0042e1b56d4c35fb0b4d0f3121 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 9 May 2024 17:25:34 +0900 Subject: [PATCH 26/65] [SPARK-47365][PYTHON] Add toArrow() DataFrame method to PySpark ### What changes were proposed in this pull request? - Add a PySpark DataFrame method `toArrow()` which returns the contents of the DataFrame as a [PyArrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html), for both local Spark and Spark Connect. - Add a new entry to the **Apache Arrow in PySpark** user guide page describing usage of the `toArrow()` method. - Add a new option to the method `_collect_as_arrow()` to provide more useful output when there are zero records returned. (This keeps the implementation of `toArrow()` simpler.) ### Why are the changes needed? In the Apache Arrow community, we hear from a lot of users who want to return the contents of a PySpark DataFrame as a PyArrow Table. Currently the only documented way to do this is to return the contents as a pandas DataFrame, then use PyArrow (`pa`) to convert that to a PyArrow Table. 
```py
pa.Table.from_pandas(df.toPandas())
```

But going through pandas adds significant overhead, which is easily avoided since internally `toPandas()` already converts the contents of the Spark DataFrame to Arrow format as an intermediate step when `spark.sql.execution.arrow.pyspark.enabled` is `true`.

Currently it is also possible to use the experimental `_collect_as_arrow()` method to return the contents of a PySpark DataFrame as a list of PyArrow RecordBatches. This PR adds a new non-experimental method `toArrow()` which returns the more user-friendly PyArrow Table object.

This PR also adds a new argument `empty_list_if_zero_records` to the experimental method `_collect_as_arrow()` to control what the method returns in the case when the result data has zero rows. If set to `True` (the default), the existing behavior is preserved, and the method returns an empty Python list. If set to `False`, the method returns a length-one list containing an empty Arrow RecordBatch which includes the schema. This is used by `toArrow()`, which requires the schema even if the data has zero rows.

For Spark Connect, there is already a `SparkSession.client.to_table()` method that returns a PyArrow table. This PR uses that to expose `toArrow()` for Spark Connect.

### Does this PR introduce _any_ user-facing change?

- It adds a DataFrame method `toArrow()` to the PySpark SQL DataFrame API.
- It adds a new argument `empty_list_if_zero_records` to the experimental DataFrame method `_collect_as_arrow()` with a default value which preserves the method's existing behavior.
- It exposes `toArrow()` for Spark Connect, via the existing `SparkSession.client.to_table()` method.
- It does not introduce any other user-facing changes.

### How was this patch tested?

This adds a new test and a new helper function for the test in `pyspark/sql/tests/test_arrow.py`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #45481 from ianmcook/SPARK-47365.
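As a quick usage sketch of the API described above (assumes PyArrow is installed and `spark` is an active SparkSession; the sample data is illustrative):

```python
# Sketch: collect a DataFrame as a PyArrow Table without a pandas round-trip.
import pyarrow as pa

df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

table = df.toArrow()          # returns a pyarrow.Table directly
assert isinstance(table, pa.Table)
print(table.schema)           # age: int64, name: string

# The previously documented route goes through pandas first:
legacy_table = pa.Table.from_pandas(df.toPandas())
```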
Lead-authored-by: Ian Cook Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- examples/src/main/python/sql/arrow.py | 18 +++++++ .../reference/pyspark.sql/dataframe.rst | 1 + .../source/user_guide/sql/arrow_pandas.rst | 49 +++++++++++++------ python/pyspark/sql/classic/dataframe.py | 4 ++ python/pyspark/sql/connect/dataframe.py | 4 ++ python/pyspark/sql/dataframe.py | 30 ++++++++++++ python/pyspark/sql/pandas/conversion.py | 48 ++++++++++++++++-- python/pyspark/sql/tests/test_arrow.py | 35 +++++++++++++ 8 files changed, 169 insertions(+), 20 deletions(-) diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 03daf18eadbf3..48aee48d929c8 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -33,6 +33,22 @@ require_minimum_pyarrow_version() +def dataframe_to_arrow_table_example(spark: SparkSession) -> None: + import pyarrow as pa # noqa: F401 + from pyspark.sql.functions import rand + + # Create a Spark DataFrame + df = spark.range(100).drop("id").withColumns({"0": rand(), "1": rand(), "2": rand()}) + + # Convert the Spark DataFrame to a PyArrow Table + table = df.select("*").toArrow() + + print(table.schema) + # 0: double not null + # 1: double not null + # 2: double not null + + def dataframe_with_arrow_example(spark: SparkSession) -> None: import numpy as np import pandas as pd @@ -302,6 +318,8 @@ def arrow_slen(s): # type: ignore[no-untyped-def] .appName("Python Arrow-in-Spark example") \ .getOrCreate() + print("Running Arrow conversion example: DataFrame to Table") + dataframe_to_arrow_table_example(spark) print("Running Pandas to/from conversion example") dataframe_with_arrow_example(spark) print("Running pandas_udf example: Series to Frame") diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst index b69a2771b04fc..ec39b645b1403 100644 --- a/python/docs/source/reference/pyspark.sql/dataframe.rst +++ b/python/docs/source/reference/pyspark.sql/dataframe.rst @@ -109,6 +109,7 @@ DataFrame DataFrame.tail DataFrame.take DataFrame.to + DataFrame.toArrow DataFrame.toDF DataFrame.toJSON DataFrame.toLocalIterator diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst index 1d6a4df606906..0a527d832e211 100644 --- a/python/docs/source/user_guide/sql/arrow_pandas.rst +++ b/python/docs/source/user_guide/sql/arrow_pandas.rst @@ -39,6 +39,20 @@ is installed and available on all cluster nodes. You can install it using pip or conda from the conda-forge channel. See PyArrow `installation `_ for details. +Conversion to Arrow Table +------------------------- + +You can call :meth:`DataFrame.toArrow` to convert a Spark DataFrame to a PyArrow Table. + +.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py + :language: python + :lines: 37-49 + :dedent: 4 + +Note that :meth:`DataFrame.toArrow` results in the collection of all records in the DataFrame to +the driver program and should be done on a small subset of the data. Not all Spark data types are +currently supported and an error can be raised if a column has an unsupported type. + Enabling for Conversion to/from Pandas -------------------------------------- @@ -53,7 +67,7 @@ This can be controlled by ``spark.sql.execution.arrow.pyspark.fallback.enabled`` .. 
literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 37-52 + :lines: 53-68 :dedent: 4 Using the above optimizations with Arrow will produce the same results as when Arrow is not @@ -90,7 +104,7 @@ specify the type hints of ``pandas.Series`` and ``pandas.DataFrame`` as below: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 56-80 + :lines: 72-96 :dedent: 4 In the following sections, it describes the combinations of the supported type hints. For simplicity, @@ -113,7 +127,7 @@ The following example shows how to create this Pandas UDF that computes the prod .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 84-114 + :lines: 100-130 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -152,7 +166,7 @@ The following example shows how to create this Pandas UDF: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 118-140 + :lines: 134-156 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -174,7 +188,7 @@ The following example shows how to create this Pandas UDF: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 144-167 + :lines: 160-183 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -205,7 +219,7 @@ and window operations: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 171-212 + :lines: 187-228 :dedent: 4 .. currentmodule:: pyspark.sql.functions @@ -270,7 +284,7 @@ in the group. .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 216-234 + :lines: 232-250 :dedent: 4 For detailed usage, please see please see :meth:`GroupedData.applyInPandas` @@ -288,7 +302,7 @@ The following example shows how to use :meth:`DataFrame.mapInPandas`: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 238-249 + :lines: 254-265 :dedent: 4 For detailed usage, please see :meth:`DataFrame.mapInPandas`. @@ -327,7 +341,7 @@ The following example shows how to use ``DataFrame.groupby().cogroup().applyInPa .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 253-275 + :lines: 269-291 :dedent: 4 @@ -349,7 +363,7 @@ Here's an example that demonstrates the usage of both a default, pickled Python .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 279-297 + :lines: 295-313 :dedent: 4 Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF @@ -421,9 +435,12 @@ be verified by the user. Setting Arrow ``self_destruct`` for memory savings ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled`` can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas DataFrame. -This option is experimental, and some operations may fail on the resulting Pandas DataFrame due to immutable backing arrays. -Typically, you would see the error ``ValueError: buffer source array is read-only``. -Newer versions of Pandas may fix these errors by improving support for such cases. -You can work around this error by copying the column(s) beforehand. 
-Additionally, this conversion may be slower because it is single-threaded. +Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled`` +can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a +Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas +DataFrame. This option can also save memory when creating a PyArrow Table via ``toArrow``. +This option is experimental. When used with ``toPandas``, some operations may fail on the resulting +Pandas DataFrame due to immutable backing arrays. Typically, you would see the error +``ValueError: buffer source array is read-only``. Newer versions of Pandas may fix these errors by +improving support for such cases. You can work around this error by copying the column(s) +beforehand. Additionally, this conversion may be slower because it is single-threaded. diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index db9f22517ddad..9b6790d29aaa7 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -74,6 +74,7 @@ if TYPE_CHECKING: from py4j.java_gateway import JavaObject + import pyarrow as pa from pyspark.core.rdd import RDD from pyspark.core.context import SparkContext from pyspark._typing import PrimitiveType @@ -1825,6 +1826,9 @@ def mapInArrow( ) -> ParentDataFrame: return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile) + def toArrow(self) -> "pa.Table": + return PandasConversionMixin.toArrow(self) + def toPandas(self) -> "PandasDataFrameLike": return PandasConversionMixin.toPandas(self) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 843c92a9b27d2..3c9415adec2dd 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1768,6 +1768,10 @@ def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: assert table is not None return (table, schema) + def toArrow(self) -> "pa.Table": + table, _ = self._to_table() + return table + def toPandas(self) -> "PandasDataFrameLike": query = self._plan.to_proto(self._session.client) return self._session.client.to_pandas(query, self._plan.observations) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e3d52c45d0c1d..886f72cc371e9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from py4j.java_gateway import JavaObject + import pyarrow as pa from pyspark.core.context import SparkContext from pyspark.core.rdd import RDD from pyspark._typing import PrimitiveType @@ -1200,6 +1201,7 @@ def collect(self) -> List[Row]: DataFrame.take : Returns the first `n` rows. DataFrame.head : Returns the first `n` rows. DataFrame.toPandas : Returns the data as a pandas DataFrame. + DataFrame.toArrow : Returns the data as a PyArrow Table. Notes ----- @@ -6213,6 +6215,34 @@ def mapInArrow( """ ... + @dispatch_df_method + def toArrow(self) -> "pa.Table": + """ + Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``. + + This is only available if PyArrow is installed and available. + + .. versionadded:: 4.0.0 + + Notes + ----- + This method should only be used if the resulting PyArrow ``pyarrow.Table`` is + expected to be small, as all the data is loaded into the driver's memory. + + This API is a developer API. 
+ + Examples + -------- + >>> df.toArrow() # doctest: +SKIP + pyarrow.Table + age: int64 + name: string + ---- + age: [[2,5]] + name: [["Alice","Bob"]] + """ + ... + def toPandas(self) -> "PandasDataFrameLike": """ Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index ec4e21daba97b..344608317beb7 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -225,15 +225,48 @@ def toPandas(self) -> "PandasDataFrameLike": else: return pdf - def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]: + def toArrow(self) -> "pa.Table": + from pyspark.sql.dataframe import DataFrame + + assert isinstance(self, DataFrame) + + jconf = self.sparkSession._jconf + + from pyspark.sql.pandas.types import to_arrow_schema + from pyspark.sql.pandas.utils import require_minimum_pyarrow_version + + require_minimum_pyarrow_version() + to_arrow_schema(self.schema) + + import pyarrow as pa + + self_destruct = jconf.arrowPySparkSelfDestructEnabled() + batches = self._collect_as_arrow( + split_batches=self_destruct, empty_list_if_zero_records=False + ) + table = pa.Table.from_batches(batches) + # Ensure only the table has a reference to the batches, so that + # self_destruct (if enabled) is effective + del batches + return table + + def _collect_as_arrow( + self, + split_batches: bool = False, + empty_list_if_zero_records: bool = True, + ) -> List["pa.RecordBatch"]: """ - Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + Returns all records as a list of Arrow RecordBatches. PyArrow must be installed and available on driver and worker Python environments. This is an experimental feature. :param split_batches: split batches such that each column is in its own allocation, so that the selfDestruct optimization is effective; default False. + :param empty_list_if_zero_records: If True (the default), returns an empty list if the + result has 0 records. Otherwise, returns a list of length 1 containing an empty + Arrow RecordBatch which includes the schema. + .. note:: Experimental. 
""" from pyspark.sql.dataframe import DataFrame @@ -282,8 +315,15 @@ def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch batches = results[:-1] batch_order = results[-1] - # Re-order the batch list using the correct order - return [batches[i] for i in batch_order] + if len(batches) or empty_list_if_zero_records: + # Re-order the batch list using the correct order + return [batches[i] for i in batch_order] + else: + from pyspark.sql.pandas.types import to_arrow_schema + + schema = to_arrow_schema(self.schema) + empty_arrays = [pa.array([], type=field.type) for field in schema] + return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)] class SparkConversionMixin: diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 8636e953aaf8f..71d3c46e5ee1e 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -179,6 +179,35 @@ def create_pandas_data_frame(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) return pd.DataFrame(data=data_dict) + def create_arrow_table(self): + import pyarrow as pa + import pyarrow.compute as pc + + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + t = pa.Table.from_pydict(data_dict) + # convert these to Arrow types + new_schema = t.schema.set( + t.schema.get_field_index("2_int_t"), pa.field("2_int_t", pa.int32()) + ) + new_schema = new_schema.set( + new_schema.get_field_index("4_float_t"), pa.field("4_float_t", pa.float32()) + ) + new_schema = new_schema.set( + new_schema.get_field_index("6_decimal_t"), + pa.field("6_decimal_t", pa.decimal128(38, 18)), + ) + t = t.cast(new_schema) + # convert timestamp to local timezone + timezone = self.spark.conf.get("spark.sql.session.timeZone") + t = t.set_column( + t.schema.get_field_index("8_timestamp_t"), + "8_timestamp_t", + pc.assume_timezone(t["8_timestamp_t"], timezone), + ) + return t + @property def create_np_arrs(self): import numpy as np @@ -339,6 +368,12 @@ def test_pandas_round_trip(self): pdf_arrow = df.toPandas() assert_frame_equal(pdf_arrow, pdf) + def test_arrow_round_trip(self): + t_in = self.create_arrow_table() + df = self.spark.createDataFrame(self.data, schema=self.schema) + t_out = df.toArrow() + self.assertTrue(t_out.equals(t_in)) + def test_pandas_self_destruct(self): import pyarrow as pa From 9e62dbad6a75a75c68d28592c99a9e94ef74fbec Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 9 May 2024 17:29:52 +0900 Subject: [PATCH 27/65] [SPARK-48212][PYTHON][CONNECT][TESTS] Fully enable `PandasUDFParityTests.test_udf_wrong_arg` ### What changes were proposed in this pull request? Fully enable `PandasUDFParityTests.test_udf_wrong_arg` it was partially enabled before (only `check_udf_wrong_arg` in `test_udf_wrong_arg`): https://github.com/apache/spark/blob/678aeb7ef7086bd962df7ac6d1c5f39151a0515b/python/pyspark/sql/tests/pandas/test_pandas_udf.py#L127-L157 ### Why are the changes needed? test coverage ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #46498 from zhengruifeng/enable_test_udf_wrong_arg. 
Lead-authored-by: Ruifeng Zheng Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/connect/test_parity_pandas_udf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py index 7f280a009f781..364e41716474b 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py @@ -20,8 +20,7 @@ class PandasUDFParityTests(PandasUDFTestsMixin, ReusedConnectTestCase): - def test_udf_wrong_arg(self): - self.check_udf_wrong_arg() + pass if __name__ == "__main__": From 207d675110e6fa699a434e81296f6f050eb0304b Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 9 May 2024 17:27:04 +0800 Subject: [PATCH 28/65] [SPARK-48211][SQL] DB2: Read SMALLINT as ShortType ### What changes were proposed in this pull request? This PR supports read SMALLINT from DB2 as ShortType ### Why are the changes needed? - 15 bits is sufficient - we write ShortType as SMALLINT - we read smallint from other builtin jdbc sources as ShortType ### Does this PR introduce _any_ user-facing change? yes, we add a migration guide for this ### How was this patch tested? changed tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46497 from yaooqinn/SPARK-48211. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../spark/sql/jdbc/DB2IntegrationSuite.scala | 69 +++++++++++-------- docs/sql-migration-guide.md | 1 + .../apache/spark/sql/internal/SQLConf.scala | 11 +++ .../apache/spark/sql/jdbc/DB2Dialect.scala | 3 + 4 files changed, 56 insertions(+), 28 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index cedb33d491fbc..aca174cce1949 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, ByteType, ShortType, StructType} import org.apache.spark.tags.DockerTest @@ -77,32 +78,44 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { } test("Numeric types") { - val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) - val rows = df.collect() - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 10) - assert(types(0).equals("class java.lang.Integer")) - assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Long")) - assert(types(3).equals("class java.math.BigDecimal")) - assert(types(4).equals("class java.lang.Double")) - assert(types(5).equals("class java.lang.Double")) - assert(types(6).equals("class java.lang.Float")) - assert(types(7).equals("class java.math.BigDecimal")) - assert(types(8).equals("class java.math.BigDecimal")) - assert(types(9).equals("class java.math.BigDecimal")) - assert(rows(0).getInt(0) == 17) - assert(rows(0).getInt(1) == 77777) - assert(rows(0).getLong(2) == 922337203685477580L) - val bd = new BigDecimal("123456745.56789012345000000000") 
- assert(rows(0).getAs[BigDecimal](3).equals(bd)) - assert(rows(0).getDouble(4) == 42.75) - assert(rows(0).getDouble(5) == 5.4E-70) - assert(rows(0).getFloat(6) == 3.4028234663852886e+38) - assert(rows(0).getDecimal(7) == new BigDecimal("4.299900000000000000")) - assert(rows(0).getDecimal(8) == new BigDecimal("99999999999999990000.000000000000000000")) - assert(rows(0).getDecimal(9) == new BigDecimal("1234567891234567.123456789123456789")) + Seq(true, false).foreach { legacy => + withSQLConf(SQLConf.LEGACY_DB2_TIMESTAMP_MAPPING_ENABLED.key -> legacy.toString) { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 10) + if (legacy) { + assert(types(0).equals("class java.lang.Integer")) + } else { + assert(types(0).equals("class java.lang.Short")) + } + assert(types(1).equals("class java.lang.Integer")) + assert(types(2).equals("class java.lang.Long")) + assert(types(3).equals("class java.math.BigDecimal")) + assert(types(4).equals("class java.lang.Double")) + assert(types(5).equals("class java.lang.Double")) + assert(types(6).equals("class java.lang.Float")) + assert(types(7).equals("class java.math.BigDecimal")) + assert(types(8).equals("class java.math.BigDecimal")) + assert(types(9).equals("class java.math.BigDecimal")) + if (legacy) { + assert(rows(0).getInt(0) == 17) + } else { + assert(rows(0).getShort(0) == 17) + } + assert(rows(0).getInt(1) == 77777) + assert(rows(0).getLong(2) == 922337203685477580L) + val bd = new BigDecimal("123456745.56789012345000000000") + assert(rows(0).getAs[BigDecimal](3).equals(bd)) + assert(rows(0).getDouble(4) == 42.75) + assert(rows(0).getDouble(5) == 5.4E-70) + assert(rows(0).getFloat(6) == 3.4028234663852886e+38) + assert(rows(0).getDecimal(7) == new BigDecimal("4.299900000000000000")) + assert(rows(0).getDecimal(8) == new BigDecimal("99999999999999990000.000000000000000000")) + assert(rows(0).getDecimal(9) == new BigDecimal("1234567891234567.123456789123456789")) + } + } } test("Date types") { @@ -154,8 +167,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { new StructType().add("c1", ShortType).add("b", ByteType).add("c3", BooleanType)) df4.write.jdbc(jdbcUrl, "otherscopy", new Properties) val rows = sqlContext.read.jdbc(jdbcUrl, "otherscopy", new Properties).collect() - assert(rows(0).getInt(0) == 1) - assert(rows(0).getInt(1) == 20) + assert(rows(0).getShort(0) == 1) + assert(rows(0).getShort(1) == 20) assert(rows(0).getString(2) == "1") } diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index bd6604cb69c0f..8b55fb48b8b57 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -50,6 +50,7 @@ license: | - Since Spark 4.0, Oracle JDBC datasource will write TimestampType as TIMESTAMP WITH LOCAL TIME ZONE, while in Spark 3.5 and previous, write as TIMESTAMP. To restore the previous behavior, set `spark.sql.legacy.oracle.timestampMapping.enabled` to `true`. - Since Spark 4.0, MsSQL Server JDBC datasource will read TINYINT as ShortType, while in Spark 3.5 and previous, read as IntegerType. To restore the previous behavior, set `spark.sql.legacy.mssqlserver.numericMapping.enabled` to `true`. - Since Spark 4.0, MsSQL Server JDBC datasource will read DATETIMEOFFSET as TimestampType, while in Spark 3.5 and previous, read as StringType. To restore the previous behavior, set `spark.sql.legacy.mssqlserver.datetimeoffsetMapping.enabled` to `true`. 
+- Since Spark 4.0, DB2 JDBC datasource will read SMALLINT as ShortType, while in Spark 3.5 and previous, it was read as IntegerType. To restore the previous behavior, set `spark.sql.legacy.db2.numericMapping.enabled` to `true`. - Since Spark 4.0, The default value for `spark.sql.legacy.ctePrecedencePolicy` has been changed from `EXCEPTION` to `CORRECTED`. Instead of raising an error, inner CTE definitions take precedence over outer definitions. - Since Spark 4.0, The default value for `spark.sql.legacy.timeParserPolicy` has been changed from `EXCEPTION` to `CORRECTED`. Instead of raising an `INCONSISTENT_BEHAVIOR_CROSS_VERSION` error, `CANNOT_PARSE_TIMESTAMP` will be raised if ANSI mode is enable. `NULL` will be returned if ANSI mode is disabled. See [Datetime Patterns for Formatting and Parsing](sql-ref-datetime-pattern.html). - Since Spark 4.0, A bug falsely allowing `!` instead of `NOT` when `!` is not a prefix operator has been fixed. Clauses such as `expr ! IN (...)`, `expr ! BETWEEN ...`, or `col ! NULL` now raise syntax errors. To restore the previous behavior, set `spark.sql.legacy.bangEqualsNot` to `true`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index df75985043d0d..54aa87260534f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4222,6 +4222,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_DB2_TIMESTAMP_MAPPING_ENABLED = + buildConf("spark.sql.legacy.db2.numericMapping.enabled") + .internal() + .doc("When true, SMALLINT maps to IntegerType in DB2; otherwise, ShortType" ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CSV_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.csv.filterPushdown.enabled") .doc("When true, enable filter pushdown to CSV datasource.") .version("3.0.0") @@ -5339,6 +5347,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def legacyOracleTimestampMappingEnabled: Boolean = getConf(LEGACY_ORACLE_TIMESTAMP_MAPPING_ENABLED) + def legacyDB2numericMappingEnabled: Boolean = + getConf(LEGACY_DB2_TIMESTAMP_MAPPING_ENABLED) + override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = { LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 31a7c783ba60e..cc596a5f0185e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ private case class DB2Dialect() extends JdbcDialect { @@ -86,6 +87,8 @@ private case class DB2Dialect() extends JdbcDialect { typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = sqlType match { + case Types.SMALLINT if !SQLConf.get.legacyDB2numericMappingEnabled => + Option(ShortType) case Types.REAL => Option(FloatType) case Types.OTHER => typeName match { From 3fd38d4c07f6c998ec8bb234796f83a6aecfc0d2 Mon Sep 
17 00:00:00 2001 From: Chenhao Li Date: Thu, 9 May 2024 22:45:10 +0800 Subject: [PATCH 29/65] [SPARK-47803][FOLLOWUP] Check nulls when casting nested type to variant ### What changes were proposed in this pull request? It adds null checks when accessing a nested element when casting a nested type to variant. It is necessary because the `get` API doesn't guarantee to return null when the slot is null. For example, `ColumnarArray.get` may return the default value of a primitive type if the slot is null. ### Why are the changes needed? It is a bug fix is necessary for the cast-to-variant expression to work correctly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Two new unit tests. One directly uses `ColumnarArray` as the input of the cast. The other creates a real-world situation where `ColumnarArray` is the input of the cast (scan). Both of them would fail without the code change in this PR. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46486 from chenhao-db/fix_cast_nested_to_variant. Authored-by: Chenhao Li Signed-off-by: Wenchen Fan --- .../variant/VariantExpressionEvalUtils.scala | 9 +++-- .../spark/sql/VariantEndToEndSuite.scala | 33 +++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index eb235eb854e09..f7f7097173bb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -103,7 +103,8 @@ object VariantExpressionEvalUtils { val offsets = new java.util.ArrayList[java.lang.Integer](data.numElements()) for (i <- 0 until data.numElements()) { offsets.add(builder.getWritePos - start) - buildVariant(builder, data.get(i, elementType), elementType) + val element = if (data.isNullAt(i)) null else data.get(i, elementType) + buildVariant(builder, element, elementType) } builder.finishWritingArray(start, offsets) case MapType(StringType, valueType, _) => @@ -116,7 +117,8 @@ object VariantExpressionEvalUtils { val key = keys.getUTF8String(i).toString val id = builder.addKey(key) fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start)) - buildVariant(builder, values.get(i, valueType), valueType) + val value = if (values.isNullAt(i)) null else values.get(i, valueType) + buildVariant(builder, value, valueType) } builder.finishWritingObject(start, fields) case StructType(structFields) => @@ -127,7 +129,8 @@ object VariantExpressionEvalUtils { val key = structFields(i).name val id = builder.addKey(key) fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start)) - buildVariant(builder, data.get(i, structFields(i).dataType), structFields(i).dataType) + val value = if (data.isNullAt(i)) null else data.get(i, structFields(i).dataType) + buildVariant(builder, value, structFields(i).dataType) } builder.finishWritingObject(start, fields) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index 3964bf3aedece..53be9d50d351e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -16,11 +16,13 @@ */ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateNamedStruct, JsonToStructs, Literal, StructsToJson} +import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, CreateNamedStruct, JsonToStructs, Literal, StructsToJson} import org.apache.spark.sql.catalyst.expressions.variant.ParseJson import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.VariantType +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.types.variant.VariantBuilder import org.apache.spark.unsafe.types.VariantVal @@ -250,4 +252,31 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { Seq.fill(3)(Row("STRUCT>")) ++ Seq(Row("STRUCT>"))) } } + + test("cast to variant with ColumnarArray input") { + val dataVector = new OnHeapColumnVector(4, LongType) + dataVector.appendNull() + dataVector.appendLong(123) + dataVector.appendNull() + dataVector.appendLong(456) + val array = new ColumnarArray(dataVector, 0, 4) + val variant = Cast(Literal(array, ArrayType(LongType)), VariantType).eval() + assert(variant.toString == "[null,123,null,456]") + dataVector.close() + } + + test("cast to variant with scan input") { + withTempPath { dir => + val path = dir.getAbsolutePath + val input = Seq(Row(Array(1, null), Map("k1" -> null, "k2" -> false), Row(null, "str"))) + val schema = StructType.fromDDL( + "a array, m map, s struct") + spark.createDataFrame(spark.sparkContext.parallelize(input), schema).write.parquet(path) + val df = spark.read.parquet(path).selectExpr( + s"cast(cast(a as variant) as ${schema(0).dataType.sql})", + s"cast(cast(m as variant) as ${schema(1).dataType.sql})", + s"cast(cast(s as variant) as ${schema(2).dataType.sql})") + checkAnswer(df, input) + } + } } From 21333f8c1fc01756e6708ad6ccf21f585fcb881d Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 9 May 2024 23:05:20 +0800 Subject: [PATCH 30/65] [SPARK-47409][SQL] Add support for collation for StringTrim type of functions/expressions (for UTF8_BINARY & LCASE) Recreating [original PR](https://github.com/apache/spark/pull/45749) because code has been reorganized in [this PR](https://github.com/apache/spark/pull/45978). ### What changes were proposed in this pull request? This PR is created to add support for collations to StringTrim family of functions/expressions, specifically: - `StringTrim` - `StringTrimBoth` - `StringTrimLeft` - `StringTrimRight` Changes: - `CollationSupport.java` - Add new `StringTrim`, `StringTrimLeft` and `StringTrimRight` classes with corresponding logic. - `CollationAwareUTF8String` - add new `trim`, `trimLeft` and `trimRight` methods that actually implement trim logic. - `UTF8String.java` - expose some of the methods publicly. - `stringExpressions.scala` - Change input types. - Change eval and code gen logic. - `CollationTypeCasts.scala` - add `StringTrim*` expressions to `CollationTypeCasts` rules. ### Why are the changes needed? We are incrementally adding collation support to a built-in string functions in Spark. ### Does this PR introduce _any_ user-facing change? Yes: - User should now be able to use non-default collations in string trim functions. ### How was this patch tested? Already existing tests + new unit/e2e tests. 
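For illustration, a sketch of the behaviour these changes enable, expressed in SQL through PySpark (assumes a build that includes this patch and that `spark` is an active SparkSession; the collation name follows the `UTF8_BINARY_LCASE` spelling used elsewhere in this series):

```python
# Sketch: collation-aware trimming. Under the case-insensitive LCASE collation,
# both 'x' and 'X' match the trim character, so the expected result is 'hello'.
spark.sql("""
    SELECT TRIM(BOTH collate('x', 'UTF8_BINARY_LCASE')
                FROM collate('xXhelloXx', 'UTF8_BINARY_LCASE')) AS trimmed
""").show()
```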
### Was this patch authored or co-authored using generative AI tooling? No. Closes #46206 from davidm-db/string-trim-functions. Authored-by: David Milicevic Signed-off-by: Wenchen Fan --- .../util/CollationAwareUTF8String.java | 470 +++++++++++++++ .../sql/catalyst/util/CollationSupport.java | 534 +++++++----------- .../apache/spark/unsafe/types/UTF8String.java | 2 +- .../unsafe/types/CollationSupportSuite.java | 193 +++++++ .../analysis/CollationTypeCasts.scala | 2 +- .../expressions/stringExpressions.scala | 53 +- .../sql/CollationStringExpressionsSuite.scala | 161 +++++- 7 files changed, 1054 insertions(+), 361 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java new file mode 100644 index 0000000000000..ee0d611d7e652 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.util; + +import com.ibm.icu.lang.UCharacter; +import com.ibm.icu.text.BreakIterator; +import com.ibm.icu.text.StringSearch; +import com.ibm.icu.util.ULocale; + +import org.apache.spark.unsafe.UTF8StringBuilder; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; +import static org.apache.spark.unsafe.Platform.copyMemory; + +import java.util.HashMap; +import java.util.Map; + +/** + * Utility class for collation-aware UTF8String operations. + */ +public class CollationAwareUTF8String { + public static UTF8String replace(final UTF8String src, final UTF8String search, + final UTF8String replace, final int collationId) { + // This collation aware implementation is based on existing implementation on UTF8String + if (src.numBytes() == 0 || search.numBytes() == 0) { + return src; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId); + + // Find the first occurrence of the search string. + int end = stringSearch.next(); + if (end == StringSearch.DONE) { + // Search string was not found, so string is unchanged. + return src; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. 
+ int increase = Math.max(0, Math.abs(replace.numBytes() - search.numBytes())) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + while (end != StringSearch.DONE) { + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); + buf.append(replace); + + // Move byteStart to the beginning of the current match + byteStart = byteEnd; + int cs = c; + // Move cs to the end of the current match + // This is necessary because the search string may contain 'multi-character' characters + while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) { + byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart)); + cs += 1; + } + // Go to next match + end = stringSearch.next(); + // Update byte positions + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, + src.numBytes() - byteStart); + return buf.build(); + } + + public static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, + final UTF8String replace) { + if (src.numBytes() == 0 || search.numBytes() == 0) { + return src; + } + UTF8String lowercaseString = src.toLowerCase(); + UTF8String lowercaseSearch = search.toLowerCase(); + + int start = 0; + int end = lowercaseString.indexOf(lowercaseSearch, 0); + if (end == -1) { + // Search string was not found, so string is unchanged. + return src; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. 
+ int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + while (end != -1) { + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); + buf.append(replace); + // Update character positions + start = end + lowercaseSearch.numChars(); + end = lowercaseString.indexOf(lowercaseSearch, start); + // Update byte positions + byteStart = byteEnd + search.numBytes(); + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, + src.numBytes() - byteStart); + return buf.build(); + } + + public static String toUpperCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + .collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toUpperCase(locale, target); + } + + public static String toLowerCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + .collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toLowerCase(locale, target); + } + + public static String toTitleCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + .collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toTitleCase(locale, target, BreakIterator.getWordInstance(locale)); + } + + public static int findInSet(final UTF8String match, final UTF8String set, int collationId) { + if (match.contains(UTF8String.fromString(","))) { + return 0; + } + + String setString = set.toString(); + StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(), + collationId); + + int wordStart = 0; + while ((wordStart = stringSearch.next()) != StringSearch.DONE) { + boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; + boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() + || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; + + if (isValidStart && isValidEnd) { + int pos = 0; + for (int i = 0; i < setString.length() && i < wordStart; i++) { + if (setString.charAt(i) == ',') { + pos++; + } + } + + return pos + 1; + } + } + + return 0; + } + + public static int indexOf(final UTF8String target, final UTF8String pattern, + final int start, final int collationId) { + if (pattern.numBytes() == 0) { + return 0; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); + stringSearch.setIndex(start); + + return stringSearch.next(); + } + + public static int find(UTF8String target, UTF8String pattern, int start, + int collationId) { + assert (pattern.numBytes() > 0); + + StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); + // Set search start position (start from character at start position) + stringSearch.setIndex(target.bytePosToChar(start)); + + // Return either the byte position or -1 if not found + return target.charPosToByte(stringSearch.next()); + } + + public static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter, + int count, final int collationId) { + if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) { + return UTF8String.EMPTY_UTF8; + } + if (count > 0) { + int idx = -1; + while (count > 0) { + idx = find(string, delimiter, idx + 1, 
collationId); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return string; + } + } + if (idx == 0) { + return UTF8String.EMPTY_UTF8; + } + byte[] bytes = new byte[idx]; + copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); + return UTF8String.fromBytes(bytes); + + } else { + count = -count; + + StringSearch stringSearch = CollationFactory + .getStringSearch(string, delimiter, collationId); + + int start = string.numChars() - 1; + int lastMatchLength = 0; + int prevStart = -1; + while (count > 0) { + stringSearch.reset(); + prevStart = -1; + int matchStart = stringSearch.next(); + lastMatchLength = stringSearch.getMatchLength(); + while (matchStart <= start) { + if (matchStart != StringSearch.DONE) { + // Found a match, update the start position + prevStart = matchStart; + matchStart = stringSearch.next(); + } else { + break; + } + } + + if (prevStart == -1) { + // can not find enough delim + return string; + } else { + start = prevStart - 1; + count--; + } + } + + int resultStart = prevStart + lastMatchLength; + if (resultStart == string.numChars()) { + return UTF8String.EMPTY_UTF8; + } + + return string.substring(resultStart, string.numChars()); + } + } + + public static UTF8String lowercaseSubStringIndex(final UTF8String string, + final UTF8String delimiter, int count) { + if (delimiter.numBytes() == 0 || count == 0) { + return UTF8String.EMPTY_UTF8; + } + + UTF8String lowercaseString = string.toLowerCase(); + UTF8String lowercaseDelimiter = delimiter.toLowerCase(); + + if (count > 0) { + int idx = -1; + while (count > 0) { + idx = lowercaseString.find(lowercaseDelimiter, idx + 1); + if (idx >= 0) { + count--; + } else { + // can not find enough delim + return string; + } + } + if (idx == 0) { + return UTF8String.EMPTY_UTF8; + } + byte[] bytes = new byte[idx]; + copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); + return UTF8String.fromBytes(bytes); + + } else { + int idx = string.numBytes() - delimiter.numBytes() + 1; + count = -count; + while (count > 0) { + idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1); + if (idx >= 0) { + count--; + } else { + // can not find enough delim + return string; + } + } + if (idx + delimiter.numBytes() == string.numBytes()) { + return UTF8String.EMPTY_UTF8; + } + int size = string.numBytes() - delimiter.numBytes() - idx; + byte[] bytes = new byte[size]; + copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(), + bytes, BYTE_ARRAY_OFFSET, size); + return UTF8String.fromBytes(bytes); + } + } + + public static Map getCollationAwareDict(UTF8String string, + Map dict, int collationId) { + String srcStr = string.toString(); + + Map collationAwareDict = new HashMap<>(); + for (String key : dict.keySet()) { + StringSearch stringSearch = + CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId); + + int pos = 0; + while ((pos = stringSearch.next()) != StringSearch.DONE) { + int codePoint = srcStr.codePointAt(pos); + int charCount = Character.charCount(codePoint); + String newKey = srcStr.substring(pos, pos + charCount); + + boolean exists = false; + for (String existingKey : collationAwareDict.keySet()) { + if (stringSearch.getCollator().compare(existingKey, newKey) == 0) { + collationAwareDict.put(newKey, collationAwareDict.get(existingKey)); + exists = true; + break; + } + } + + if (!exists) { + collationAwareDict.put(newKey, dict.get(key)); + } + } + } + + return collationAwareDict; + } + + 
public static UTF8String lowercaseTrim( + final UTF8String srcString, + final UTF8String trimString) { + // Matching UTF8String behavior for null `trimString`. + if (trimString == null) { + return null; + } + + UTF8String leftTrimmed = lowercaseTrimLeft(srcString, trimString); + return lowercaseTrimRight(leftTrimmed, trimString); + } + + public static UTF8String lowercaseTrimLeft( + final UTF8String srcString, + final UTF8String trimString) { + // Matching UTF8String behavior for null `trimString`. + if (trimString == null) { + return null; + } + + // The searching byte position in the srcString. + int searchIdx = 0; + // The byte position of a first non-matching character in the srcString. + int trimByteIdx = 0; + // Number of bytes in srcString. + int numBytes = srcString.numBytes(); + // Convert trimString to lowercase, so it can be searched properly. + UTF8String lowercaseTrimString = trimString.toLowerCase(); + + while (searchIdx < numBytes) { + UTF8String searchChar = srcString.copyUTF8String( + searchIdx, + searchIdx + UTF8String.numBytesForFirstByte(srcString.getByte(searchIdx)) - 1); + int searchCharBytes = searchChar.numBytes(); + + // Try to find the matching for the searchChar in the trimString. + if (lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) { + trimByteIdx += searchCharBytes; + searchIdx += searchCharBytes; + } else { + // No matching, exit the search. + break; + } + } + + if (searchIdx == 0) { + // Nothing trimmed - return original string (not converted to lowercase). + return srcString; + } + if (trimByteIdx >= numBytes) { + // Everything trimmed. + return UTF8String.EMPTY_UTF8; + } + return srcString.copyUTF8String(trimByteIdx, numBytes - 1); + } + + public static UTF8String lowercaseTrimRight( + final UTF8String srcString, + final UTF8String trimString) { + // Matching UTF8String behavior for null `trimString`. + if (trimString == null) { + return null; + } + + // Number of bytes iterated from the srcString. + int byteIdx = 0; + // Number of characters iterated from the srcString. + int numChars = 0; + // Number of bytes in srcString. + int numBytes = srcString.numBytes(); + // Array of character length for the srcString. + int[] stringCharLen = new int[numBytes]; + // Array of the first byte position for each character in the srcString. + int[] stringCharPos = new int[numBytes]; + // Convert trimString to lowercase, so it can be searched properly. + UTF8String lowercaseTrimString = trimString.toLowerCase(); + + // Build the position and length array. + while (byteIdx < numBytes) { + stringCharPos[numChars] = byteIdx; + stringCharLen[numChars] = UTF8String.numBytesForFirstByte(srcString.getByte(byteIdx)); + byteIdx += stringCharLen[numChars]; + numChars++; + } + + // Index trimEnd points to the first no matching byte position from the right side of + // the source string. + int trimByteIdx = numBytes - 1; + + while (numChars > 0) { + UTF8String searchChar = srcString.copyUTF8String( + stringCharPos[numChars - 1], + stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + + if(lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) { + trimByteIdx -= stringCharLen[numChars - 1]; + numChars--; + } else { + break; + } + } + + if (trimByteIdx == numBytes - 1) { + // Nothing trimmed. + return srcString; + } + if (trimByteIdx < 0) { + // Everything trimmed. 
+ return UTF8String.EMPTY_UTF8; + } + return srcString.copyUTF8String(0, trimByteIdx); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index b77671cee90b0..bea3dc08b4489 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -16,23 +16,15 @@ */ package org.apache.spark.sql.catalyst.util; -import com.ibm.icu.lang.UCharacter; -import com.ibm.icu.text.BreakIterator; import com.ibm.icu.text.StringSearch; -import com.ibm.icu.util.ULocale; -import org.apache.spark.unsafe.UTF8StringBuilder; import org.apache.spark.unsafe.types.UTF8String; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.regex.Pattern; -import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.Platform.copyMemory; - /** * Static entry point for collation-aware expressions (StringExpressions, RegexpExpressions, and * other expressions that require custom collation support), as well as private utility methods for @@ -441,7 +433,7 @@ public static int execLowercase(final UTF8String string, final UTF8String substr return string.toLowerCase().indexOf(substring.toLowerCase(), start); } public static int execICU(final UTF8String string, final UTF8String substring, final int start, - final int collationId) { + final int collationId) { return CollationAwareUTF8String.indexOf(string, substring, start, collationId); } } @@ -535,6 +527,201 @@ public static UTF8String execICU(final UTF8String source, Map di } } + public static class StringTrim { + public static UTF8String exec( + final UTF8String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString); + } else { + return execLowercase(srcString); + } + } + public static UTF8String exec( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString, trimString); + } else { + return execLowercase(srcString, trimString); + } + } + public static String genCode( + final String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrim.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s)", srcString); + } { + return String.format(expr + "Lowercase(%s)", srcString); + } + } + public static String genCode( + final String srcString, + final String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrim.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", srcString, trimString); + } else { + return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); + } + } + public static UTF8String execBinary( + final UTF8String srcString) { + return srcString.trim(); + } + public static UTF8String execBinary( + final UTF8String srcString, + final UTF8String trimString) { + return 
srcString.trim(trimString); + } + public static UTF8String execLowercase( + final UTF8String srcString) { + return srcString.trim(); + } + public static UTF8String execLowercase( + final UTF8String srcString, + final UTF8String trimString) { + return CollationAwareUTF8String.lowercaseTrim(srcString, trimString); + } + } + + public static class StringTrimLeft { + public static UTF8String exec( + final UTF8String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString); + } else { + return execLowercase(srcString); + } + } + public static UTF8String exec( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString, trimString); + } else { + return execLowercase(srcString, trimString); + } + } + public static String genCode( + final String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrimLeft.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s)", srcString); + } else { + return String.format(expr + "Lowercase(%s)", srcString); + } + } + public static String genCode( + final String srcString, + final String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrimLeft.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", srcString, trimString); + } else { + return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); + } + } + public static UTF8String execBinary( + final UTF8String srcString) { + return srcString.trimLeft(); + } + public static UTF8String execBinary( + final UTF8String srcString, + final UTF8String trimString) { + return srcString.trimLeft(trimString); + } + public static UTF8String execLowercase( + final UTF8String srcString) { + return srcString.trimLeft(); + } + public static UTF8String execLowercase( + final UTF8String srcString, + final UTF8String trimString) { + return CollationAwareUTF8String.lowercaseTrimLeft(srcString, trimString); + } + } + + public static class StringTrimRight { + public static UTF8String exec( + final UTF8String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString); + } else { + return execLowercase(srcString); + } + } + public static UTF8String exec( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString, trimString); + } else { + return execLowercase(srcString, trimString); + } + } + public static String genCode( + final String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrimRight.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s)", srcString); + } else { + return String.format(expr + "Lowercase(%s)", srcString); + } 
+ } + public static String genCode( + final String srcString, + final String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrimRight.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", srcString, trimString); + } else { + return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); + } + } + public static UTF8String execBinary( + final UTF8String srcString) { + return srcString.trimRight(); + } + public static UTF8String execBinary( + final UTF8String srcString, + final UTF8String trimString) { + return srcString.trimRight(trimString); + } + public static UTF8String execLowercase( + final UTF8String srcString) { + return srcString.trimRight(); + } + public static UTF8String execLowercase( + final UTF8String srcString, + final UTF8String trimString) { + return CollationAwareUTF8String.lowercaseTrimRight(srcString, trimString); + } + } + // TODO: Add more collation-aware string expressions. /** @@ -566,333 +753,4 @@ public static UTF8String collationAwareRegex(final UTF8String regex, final int c // TODO: Add other collation-aware expressions. - /** - * Utility class for collation-aware UTF8String operations. - */ - - private static class CollationAwareUTF8String { - - private static UTF8String replace(final UTF8String src, final UTF8String search, - final UTF8String replace, final int collationId) { - // This collation aware implementation is based on existing implementation on UTF8String - if (src.numBytes() == 0 || search.numBytes() == 0) { - return src; - } - - StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId); - - // Find the first occurrence of the search string. - int end = stringSearch.next(); - if (end == StringSearch.DONE) { - // Search string was not found, so string is unchanged. - return src; - } - - // Initialize byte positions - int c = 0; - int byteStart = 0; // position in byte - int byteEnd = 0; // position in byte - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } - - // At least one match was found. Estimate space needed for result. - // The 16x multiplier here is chosen to match commons-lang3's implementation. 
- int increase = Math.max(0, Math.abs(replace.numBytes() - search.numBytes())) * 16; - final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); - while (end != StringSearch.DONE) { - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); - buf.append(replace); - - // Move byteStart to the beginning of the current match - byteStart = byteEnd; - int cs = c; - // Move cs to the end of the current match - // This is necessary because the search string may contain 'multi-character' characters - while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) { - byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart)); - cs += 1; - } - // Go to next match - end = stringSearch.next(); - // Update byte positions - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } - } - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, - src.numBytes() - byteStart); - return buf.build(); - } - - private static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, - final UTF8String replace) { - if (src.numBytes() == 0 || search.numBytes() == 0) { - return src; - } - UTF8String lowercaseString = src.toLowerCase(); - UTF8String lowercaseSearch = search.toLowerCase(); - - int start = 0; - int end = lowercaseString.indexOf(lowercaseSearch, 0); - if (end == -1) { - // Search string was not found, so string is unchanged. - return src; - } - - // Initialize byte positions - int c = 0; - int byteStart = 0; // position in byte - int byteEnd = 0; // position in byte - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } - - // At least one match was found. Estimate space needed for result. - // The 16x multiplier here is chosen to match commons-lang3's implementation. 
- int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16; - final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); - while (end != -1) { - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); - buf.append(replace); - // Update character positions - start = end + lowercaseSearch.numChars(); - end = lowercaseString.indexOf(lowercaseSearch, start); - // Update byte positions - byteStart = byteEnd + search.numBytes(); - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } - } - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, - src.numBytes() - byteStart); - return buf.build(); - } - - private static String toUpperCase(final String target, final int collationId) { - ULocale locale = CollationFactory.fetchCollation(collationId) - .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toUpperCase(locale, target); - } - - private static String toLowerCase(final String target, final int collationId) { - ULocale locale = CollationFactory.fetchCollation(collationId) - .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toLowerCase(locale, target); - } - - private static String toTitleCase(final String target, final int collationId) { - ULocale locale = CollationFactory.fetchCollation(collationId) - .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toTitleCase(locale, target, BreakIterator.getWordInstance(locale)); - } - - private static int findInSet(final UTF8String match, final UTF8String set, int collationId) { - if (match.contains(UTF8String.fromString(","))) { - return 0; - } - - String setString = set.toString(); - StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(), - collationId); - - int wordStart = 0; - while ((wordStart = stringSearch.next()) != StringSearch.DONE) { - boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; - boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() - || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; - - if (isValidStart && isValidEnd) { - int pos = 0; - for (int i = 0; i < setString.length() && i < wordStart; i++) { - if (setString.charAt(i) == ',') { - pos++; - } - } - - return pos + 1; - } - } - - return 0; - } - - private static int indexOf(final UTF8String target, final UTF8String pattern, - final int start, final int collationId) { - if (pattern.numBytes() == 0) { - return 0; - } - - StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); - stringSearch.setIndex(start); - - return stringSearch.next(); - } - - private static int find(UTF8String target, UTF8String pattern, int start, - int collationId) { - assert (pattern.numBytes() > 0); - - StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); - // Set search start position (start from character at start position) - stringSearch.setIndex(target.bytePosToChar(start)); - - // Return either the byte position or -1 if not found - return target.charPosToByte(stringSearch.next()); - } - - private static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter, - int count, final int collationId) { - if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) { - return UTF8String.EMPTY_UTF8; - } - if (count > 0) { - int idx = -1; - while (count > 0) { - idx = find(string, delimiter, idx 
+ 1, collationId); - if (idx >= 0) { - count --; - } else { - // can not find enough delim - return string; - } - } - if (idx == 0) { - return UTF8String.EMPTY_UTF8; - } - byte[] bytes = new byte[idx]; - copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); - return UTF8String.fromBytes(bytes); - - } else { - count = -count; - - StringSearch stringSearch = CollationFactory - .getStringSearch(string, delimiter, collationId); - - int start = string.numChars() - 1; - int lastMatchLength = 0; - int prevStart = -1; - while (count > 0) { - stringSearch.reset(); - prevStart = -1; - int matchStart = stringSearch.next(); - lastMatchLength = stringSearch.getMatchLength(); - while (matchStart <= start) { - if (matchStart != StringSearch.DONE) { - // Found a match, update the start position - prevStart = matchStart; - matchStart = stringSearch.next(); - } else { - break; - } - } - - if (prevStart == -1) { - // can not find enough delim - return string; - } else { - start = prevStart - 1; - count--; - } - } - - int resultStart = prevStart + lastMatchLength; - if (resultStart == string.numChars()) { - return UTF8String.EMPTY_UTF8; - } - - return string.substring(resultStart, string.numChars()); - } - } - - private static UTF8String lowercaseSubStringIndex(final UTF8String string, - final UTF8String delimiter, int count) { - if (delimiter.numBytes() == 0 || count == 0) { - return UTF8String.EMPTY_UTF8; - } - - UTF8String lowercaseString = string.toLowerCase(); - UTF8String lowercaseDelimiter = delimiter.toLowerCase(); - - if (count > 0) { - int idx = -1; - while (count > 0) { - idx = lowercaseString.find(lowercaseDelimiter, idx + 1); - if (idx >= 0) { - count --; - } else { - // can not find enough delim - return string; - } - } - if (idx == 0) { - return UTF8String.EMPTY_UTF8; - } - byte[] bytes = new byte[idx]; - copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); - return UTF8String.fromBytes(bytes); - - } else { - int idx = string.numBytes() - delimiter.numBytes() + 1; - count = -count; - while (count > 0) { - idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1); - if (idx >= 0) { - count --; - } else { - // can not find enough delim - return string; - } - } - if (idx + delimiter.numBytes() == string.numBytes()) { - return UTF8String.EMPTY_UTF8; - } - int size = string.numBytes() - delimiter.numBytes() - idx; - byte[] bytes = new byte[size]; - copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(), - bytes, BYTE_ARRAY_OFFSET, size); - return UTF8String.fromBytes(bytes); - } - } - - private static Map getCollationAwareDict(UTF8String string, - Map dict, int collationId) { - String srcStr = string.toString(); - - Map collationAwareDict = new HashMap<>(); - for (String key : dict.keySet()) { - StringSearch stringSearch = - CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId); - - int pos = 0; - while ((pos = stringSearch.next()) != StringSearch.DONE) { - int codePoint = srcStr.codePointAt(pos); - int charCount = Character.charCount(codePoint); - String newKey = srcStr.substring(pos, pos + charCount); - - boolean exists = false; - for (String existingKey : collationAwareDict.keySet()) { - if (stringSearch.getCollator().compare(existingKey, newKey) == 0) { - collationAwareDict.put(newKey, collationAwareDict.get(existingKey)); - exists = true; - break; - } - } - - if (!exists) { - collationAwareDict.put(newKey, dict.get(key)); - } - } - } - - return 
collationAwareDict; - } - - } - } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 2a5d145803533..20b26b6ebc5a5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -646,7 +646,7 @@ public int findInSet(UTF8String match) { * @param end the end position of the current UTF8String in bytes. * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. */ - private UTF8String copyUTF8String(int start, int end) { + public UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 2f05b9ad88c9c..7fc3c4e349c3b 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -800,6 +800,199 @@ public void testSubstringIndex() throws SparkException { assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", "i̇o12i̇oİo"); } + private void assertStringTrim( + String collation, + String sourceString, + String trimString, + String expectedResultString) throws SparkException { + int collationId = CollationFactory.collationNameToId(collation); + String result; + + if (trimString == null) { + result = CollationSupport.StringTrim.exec( + UTF8String.fromString(sourceString), collationId).toString(); + } else { + result = CollationSupport.StringTrim.exec( + UTF8String + .fromString(sourceString), UTF8String.fromString(trimString), collationId) + .toString(); + } + + assertEquals(expectedResultString, result); + } + + private void assertStringTrimLeft( + String collation, + String sourceString, + String trimString, + String expectedResultString) throws SparkException { + int collationId = CollationFactory.collationNameToId(collation); + String result; + + if (trimString == null) { + result = CollationSupport.StringTrimLeft.exec( + UTF8String.fromString(sourceString), collationId).toString(); + } else { + result = CollationSupport.StringTrimLeft.exec( + UTF8String + .fromString(sourceString), UTF8String.fromString(trimString), collationId) + .toString(); + } + + assertEquals(expectedResultString, result); + } + + private void assertStringTrimRight( + String collation, + String sourceString, + String trimString, + String expectedResultString) throws SparkException { + int collationId = CollationFactory.collationNameToId(collation); + String result; + + if (trimString == null) { + result = CollationSupport.StringTrimRight.exec( + UTF8String.fromString(sourceString), collationId).toString(); + } else { + result = CollationSupport.StringTrimRight.exec( + UTF8String + .fromString(sourceString), UTF8String.fromString(trimString), collationId) + .toString(); + } + + assertEquals(expectedResultString, result); + } + + @Test + public void testStringTrim() throws SparkException { + assertStringTrim("UTF8_BINARY", "asd", null, "asd"); + assertStringTrim("UTF8_BINARY", " asd ", null, "asd"); + assertStringTrim("UTF8_BINARY", " a世a ", null, "a世a"); + assertStringTrim("UTF8_BINARY", "asd", "x", "asd"); + 
assertStringTrim("UTF8_BINARY", "xxasdxx", "x", "asd"); + assertStringTrim("UTF8_BINARY", "xa世ax", "x", "a世a"); + + assertStringTrimLeft("UTF8_BINARY", "asd", null, "asd"); + assertStringTrimLeft("UTF8_BINARY", " asd ", null, "asd "); + assertStringTrimLeft("UTF8_BINARY", " a世a ", null, "a世a "); + assertStringTrimLeft("UTF8_BINARY", "asd", "x", "asd"); + assertStringTrimLeft("UTF8_BINARY", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UTF8_BINARY", "xa世ax", "x", "a世ax"); + + assertStringTrimRight("UTF8_BINARY", "asd", null, "asd"); + assertStringTrimRight("UTF8_BINARY", " asd ", null, " asd"); + assertStringTrimRight("UTF8_BINARY", " a世a ", null, " a世a"); + assertStringTrimRight("UTF8_BINARY", "asd", "x", "asd"); + assertStringTrimRight("UTF8_BINARY", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UTF8_BINARY", "xa世ax", "x", "xa世a"); + + assertStringTrim("UTF8_BINARY_LCASE", "asd", null, "asd"); + assertStringTrim("UTF8_BINARY_LCASE", " asd ", null, "asd"); + assertStringTrim("UTF8_BINARY_LCASE", " a世a ", null, "a世a"); + assertStringTrim("UTF8_BINARY_LCASE", "asd", "x", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "xxasdxx", "x", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "xa世ax", "x", "a世a"); + + assertStringTrimLeft("UTF8_BINARY_LCASE", "asd", null, "asd"); + assertStringTrimLeft("UTF8_BINARY_LCASE", " asd ", null, "asd "); + assertStringTrimLeft("UTF8_BINARY_LCASE", " a世a ", null, "a世a "); + assertStringTrimLeft("UTF8_BINARY_LCASE", "asd", "x", "asd"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "xa世ax", "x", "a世ax"); + + assertStringTrimRight("UTF8_BINARY_LCASE", "asd", null, "asd"); + assertStringTrimRight("UTF8_BINARY_LCASE", " asd ", null, " asd"); + assertStringTrimRight("UTF8_BINARY_LCASE", " a世a ", null, " a世a"); + assertStringTrimRight("UTF8_BINARY_LCASE", "asd", "x", "asd"); + assertStringTrimRight("UTF8_BINARY_LCASE", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UTF8_BINARY_LCASE", "xa世ax", "x", "xa世a"); + + assertStringTrim("UTF8_BINARY_LCASE", "asd", null, "asd"); + assertStringTrim("UTF8_BINARY_LCASE", " asd ", null, "asd"); + assertStringTrim("UTF8_BINARY_LCASE", " a世a ", null, "a世a"); + assertStringTrim("UTF8_BINARY_LCASE", "asd", "x", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "xxasdxx", "x", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "xa世ax", "x", "a世a"); + + assertStringTrimLeft("UNICODE", "asd", null, "asd"); + assertStringTrimLeft("UNICODE", " asd ", null, "asd "); + assertStringTrimLeft("UNICODE", " a世a ", null, "a世a "); + assertStringTrimLeft("UNICODE", "asd", "x", "asd"); + assertStringTrimLeft("UNICODE", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UNICODE", "xa世ax", "x", "a世ax"); + + assertStringTrimRight("UNICODE", "asd", null, "asd"); + assertStringTrimRight("UNICODE", " asd ", null, " asd"); + assertStringTrimRight("UNICODE", " a世a ", null, " a世a"); + assertStringTrimRight("UNICODE", "asd", "x", "asd"); + assertStringTrimRight("UNICODE", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UNICODE", "xa世ax", "x", "xa世a"); + + // Test cases where trimString has more than one character + assertStringTrim("UTF8_BINARY", "ddsXXXaa", "asd", "XXX"); + assertStringTrimLeft("UTF8_BINARY", "ddsXXXaa", "asd", "XXXaa"); + assertStringTrimRight("UTF8_BINARY", "ddsXXXaa", "asd", "ddsXXX"); + + assertStringTrim("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "XXX"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "XXXaa"); + 
assertStringTrimRight("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "ddsXXX"); + + assertStringTrim("UNICODE", "ddsXXXaa", "asd", "XXX"); + assertStringTrimLeft("UNICODE", "ddsXXXaa", "asd", "XXXaa"); + assertStringTrimRight("UNICODE", "ddsXXXaa", "asd", "ddsXXX"); + + // Test cases specific to collation type + // uppercase trim, lowercase src + assertStringTrim("UTF8_BINARY", "asd", "A", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "asd", "A", "sd"); + assertStringTrim("UNICODE", "asd", "A", "asd"); + assertStringTrim("UNICODE_CI", "asd", "A", "sd"); + + // lowercase trim, uppercase src + assertStringTrim("UTF8_BINARY", "ASD", "a", "ASD"); + assertStringTrim("UTF8_BINARY_LCASE", "ASD", "a", "SD"); + assertStringTrim("UNICODE", "ASD", "a", "ASD"); + assertStringTrim("UNICODE_CI", "ASD", "a", "SD"); + + // uppercase and lowercase chars of different byte-length (utf8) + assertStringTrim("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimLeft("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimRight("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); + + assertStringTrim("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "aaa"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "aaaẞ"); + assertStringTrimRight("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "ẞaaa"); + + assertStringTrim("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimLeft("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimRight("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); + + assertStringTrim("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimLeft("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimRight("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); + + assertStringTrim("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "aaa"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "aaaß"); + assertStringTrimRight("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "ßaaa"); + + assertStringTrim("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimLeft("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimRight("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); + + // different byte-length (utf8) chars trimmed + assertStringTrim("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaa"); + assertStringTrimLeft("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimRight("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + + assertStringTrim("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "aaa"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimRight("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + + assertStringTrim("UNICODE", "Ëaaaẞ", "Ëẞ", "aaa"); + assertStringTrimLeft("UNICODE", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimRight("UNICODE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + } + // TODO: Test more collation-aware string expressions. 
/** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 44349384187ef..a50dad7c8cdb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -74,7 +74,7 @@ object CollationTypeCasts extends TypeCoercionRule { case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace | - _: StringTranslate) => + _: StringTranslate | _: StringTrim | _: StringTrimLeft | _: StringTrimRight) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 0bdd7930b0bf9..09ec501311ade 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods @@ -1020,8 +1020,10 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { protected def direction: String override def children: Seq[Expression] = srcStr +: trimStr.toSeq - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = srcStr.dataType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeBinaryLcase) + + final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -1040,13 +1042,19 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { } } - protected val trimMethod: String - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) - val srcString = evals(0) + val srcString = evals.head if (evals.length == 1) { + val stringTrimCode: String = this match { + case _: StringTrim => + CollationSupport.StringTrim.genCode(srcString.value, collationId) + case _: StringTrimLeft => + CollationSupport.StringTrimLeft.genCode(srcString.value, collationId) + case _: StringTrimRight => + CollationSupport.StringTrimRight.genCode(srcString.value, collationId) + } ev.copy(code = code""" |${srcString.code} |boolean ${ev.isNull} = false; @@ -1054,10 +1062,18 @@ trait String2TrimExpression extends Expression 
with ImplicitCastInputTypes { |if (${srcString.isNull}) { | ${ev.isNull} = true; |} else { - | ${ev.value} = ${srcString.value}.$trimMethod(); + | ${ev.value} = $stringTrimCode; |}""".stripMargin) } else { val trimString = evals(1) + val stringTrimCode: String = this match { + case _: StringTrim => + CollationSupport.StringTrim.genCode(srcString.value, trimString.value, collationId) + case _: StringTrimLeft => + CollationSupport.StringTrimLeft.genCode(srcString.value, trimString.value, collationId) + case _: StringTrimRight => + CollationSupport.StringTrimRight.genCode(srcString.value, trimString.value, collationId) + } ev.copy(code = code""" |${srcString.code} |boolean ${ev.isNull} = false; @@ -1069,7 +1085,7 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { | if (${trimString.isNull}) { | ${ev.isNull} = true; | } else { - | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}); + | ${ev.value} = $stringTrimCode; | } |}""".stripMargin) } @@ -1162,12 +1178,11 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override protected def direction: String = "BOTH" - override def doEval(srcString: UTF8String): UTF8String = srcString.trim() + override def doEval(srcString: UTF8String): UTF8String = + CollationSupport.StringTrim.exec(srcString, collationId) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trim(trimString) - - override val trimMethod: String = "trim" + CollationSupport.StringTrim.exec(srcString, trimString, collationId) override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( @@ -1270,12 +1285,11 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override protected def direction: String = "LEADING" - override def doEval(srcString: UTF8String): UTF8String = srcString.trimLeft() + override def doEval(srcString: UTF8String): UTF8String = + CollationSupport.StringTrimLeft.exec(srcString, collationId) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trimLeft(trimString) - - override val trimMethod: String = "trimLeft" + CollationSupport.StringTrimLeft.exec(srcString, trimString, collationId) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimLeft = @@ -1331,12 +1345,11 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override protected def direction: String = "TRAILING" - override def doEval(srcString: UTF8String): UTF8String = srcString.trimRight() + override def doEval(srcString: UTF8String): UTF8String = + CollationSupport.StringTrimRight.exec(srcString, collationId) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trimRight(trimString) - - override val trimMethod: String = "trimRight" + CollationSupport.StringTrimRight.exec(srcString, trimString, collationId) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimRight = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index b9a4fecd0465b..9cc123b708aff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import 
org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal, StringTrim, StringTrimLeft, StringTrimRight} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -26,7 +27,8 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, // scalastyle:off nonascii class CollationStringExpressionsSuite extends QueryTest - with SharedSparkSession { + with SharedSparkSession + with ExpressionEvalHelper { test("Support ConcatWs string expression with collation") { // Supported collations @@ -800,6 +802,163 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("StringTrim* functions - unit tests for both paths (codegen and eval)") { + // Without trimString param. + checkEvaluation(StringTrim(Literal.create( " asd ", StringType("UTF8_BINARY"))), "asd") + checkEvaluation( + StringTrimLeft(Literal.create(" asd ", StringType("UTF8_BINARY_LCASE"))), "asd ") + checkEvaluation(StringTrimRight(Literal.create(" asd ", StringType("UNICODE"))), " asd") + + // With trimString param. + checkEvaluation( + StringTrim( + Literal.create(" asd ", StringType("UTF8_BINARY")), + Literal.create(" ", StringType("UTF8_BINARY"))), + "asd") + checkEvaluation( + StringTrimLeft( + Literal.create(" asd ", StringType("UTF8_BINARY_LCASE")), + Literal.create(" ", StringType("UTF8_BINARY_LCASE"))), + "asd ") + checkEvaluation( + StringTrimRight( + Literal.create(" asd ", StringType("UNICODE")), + Literal.create(" ", StringType("UNICODE"))), + " asd") + + checkEvaluation( + StringTrim( + Literal.create("xxasdxx", StringType("UTF8_BINARY")), + Literal.create("x", StringType("UTF8_BINARY"))), + "asd") + checkEvaluation( + StringTrimLeft( + Literal.create("xxasdxx", StringType("UTF8_BINARY_LCASE")), + Literal.create("x", StringType("UTF8_BINARY_LCASE"))), + "asdxx") + checkEvaluation( + StringTrimRight( + Literal.create("xxasdxx", StringType("UNICODE")), + Literal.create("x", StringType("UNICODE"))), + "xxasd") + } + + test("StringTrim* functions - E2E tests") { + case class StringTrimTestCase( + collation: String, + trimFunc: String, + sourceString: String, + hasTrimString: Boolean, + trimString: String, + expectedResultString: String) + + val testCases = Seq( + StringTrimTestCase("UTF8_BINARY", "TRIM", " asd ", false, null, "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", " asd ", true, null, null), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "xxasdxx", true, "x", "asdxx"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "xxasdxx", true, "x", "xxasd"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " asd ", true, null, null), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "xxasdxx", true, "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "xxasdxx", true, "x", "asdxx"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", " asd ", false, null, " asd"), + + StringTrimTestCase("UNICODE", "TRIM", "xxasdxx", true, "x", "asd"), + StringTrimTestCase("UNICODE", "BTRIM", "xxasdxx", true, "x", "asd"), + StringTrimTestCase("UNICODE", "LTRIM", " asd ", false, null, "asd "), + StringTrimTestCase("UNICODE", "RTRIM", " asd ", true, null, null) + + // Other more complex cases can be found in unit tests in CollationSupportSuite.java. 
+ ) + + testCases.foreach(testCase => { + var df: DataFrame = null + + if (testCase.trimFunc.equalsIgnoreCase("BTRIM")) { + // BTRIM has arguments in (srcStr, trimStr) order + df = sql(s"SELECT ${testCase.trimFunc}(" + + s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" + + (if (!testCase.hasTrimString) "" + else if (testCase.trimString == null) ", null" + else s", '${testCase.trimString}'") + + ")") + } + else { + // While other functions have arguments in (trimStr, srcStr) order + df = sql(s"SELECT ${testCase.trimFunc}(" + + (if (!testCase.hasTrimString) "" + else if (testCase.trimString == null) "null, " + else s"'${testCase.trimString}', ") + + s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" + + ")") + } + + checkAnswer(df = df, expectedAnswer = Row(testCase.expectedResultString)) + }) + } + + test("StringTrim* functions - implicit collations") { + checkAnswer( + df = sql("SELECT TRIM(COLLATE('x', 'UTF8_BINARY'), COLLATE('xax', 'UTF8_BINARY'))"), + expectedAnswer = Row("a")) + checkAnswer( + df = sql("SELECT BTRIM(COLLATE('xax', 'UTF8_BINARY_LCASE'), " + + "COLLATE('x', 'UTF8_BINARY_LCASE'))"), + expectedAnswer = Row("a")) + checkAnswer( + df = sql("SELECT LTRIM(COLLATE('x', 'UNICODE'), COLLATE('xax', 'UNICODE'))"), + expectedAnswer = Row("ax")) + + checkAnswer( + df = sql("SELECT RTRIM('x', COLLATE('xax', 'UTF8_BINARY'))"), + expectedAnswer = Row("xa")) + checkAnswer( + df = sql("SELECT TRIM('x', COLLATE('xax', 'UTF8_BINARY_LCASE'))"), + expectedAnswer = Row("a")) + checkAnswer( + df = sql("SELECT BTRIM('xax', COLLATE('x', 'UNICODE'))"), + expectedAnswer = Row("a")) + + checkAnswer( + df = sql("SELECT LTRIM(COLLATE('x', 'UTF8_BINARY'), 'xax')"), + expectedAnswer = Row("ax")) + checkAnswer( + df = sql("SELECT RTRIM(COLLATE('x', 'UTF8_BINARY_LCASE'), 'xax')"), + expectedAnswer = Row("xa")) + checkAnswer( + df = sql("SELECT TRIM(COLLATE('x', 'UNICODE'), 'xax')"), + expectedAnswer = Row("a")) + } + + test("StringTrim* functions - collation type mismatch") { + List("TRIM", "LTRIM", "RTRIM").foreach(func => { + val collationMismatch = intercept[AnalysisException] { + sql("SELECT " + func + "(COLLATE('x', 'UTF8_BINARY_LCASE'), " + + "COLLATE('xxaaaxx', 'UNICODE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + }) + + val collationMismatch = intercept[AnalysisException] { + sql("SELECT BTRIM(COLLATE('xxaaaxx', 'UNICODE'), COLLATE('x', 'UTF8_BINARY_LCASE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + + test("StringTrim* functions - unsupported collation types") { + List("TRIM", "LTRIM", "RTRIM").foreach(func => { + val collationMismatch = intercept[AnalysisException] { + sql("SELECT " + func + "(COLLATE('x', 'UNICODE_CI'), COLLATE('xxaaaxx', 'UNICODE_CI'))") + } + assert(collationMismatch.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + + val collationMismatch = intercept[AnalysisException] { + sql("SELECT BTRIM(COLLATE('xxaaaxx', 'UNICODE_CI'), COLLATE('x', 'UNICODE_CI'))") + } + assert(collationMismatch.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + } + // TODO: Add more tests for other string expressions } From e1fb1d7e063af7e8eb6e992c800902aff6e19e15 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 9 May 2024 08:37:07 -0700 Subject: [PATCH 31/65] [SPARK-48216][TESTS] Remove overrides DockerJDBCIntegrationSuite.connectionTimeout to make related tests configurable ### What changes were proposed in this pull request? 
This PR removes overrides DockerJDBCIntegrationSuite.connectionTimeout to make related tests configurable. ### Why are the changes needed? The db dockers might require more time to bootstrap sometimes. It shall be configurable to avoid failure like: ```scala [info] org.apache.spark.sql.jdbc.DB2IntegrationSuite *** ABORTED *** (3 minutes, 11 seconds) [info] The code passed to eventually never returned normally. Attempted 96 times over 3.003998157633333 minutes. Last failure message: [jcc][t4][2030][11211][4.33.31] A communication error occurred during operations on the connection's underlying socket, socket input stream, [info] or socket output stream. Error location: Reply.fill() - insufficient data (-1). Message: Insufficient data. ERRORCODE=-4499, SQLSTATE=08001. (DockerJDBCIntegrationSuite.scala:215) [info] org.scalatest.exceptions.TestFailedDueToTimeoutException: [info] at org.scalatest.enablers.Retrying$$anon$4.tryTryAgain$2(Retrying.scala:219) [info] at org.scalatest.enablers.Retrying$$anon$4.retry(Retrying.scala:226) [info] at org.scalatest.concurrent.Eventually.eventually(Eventually.scala:313) [info] at org.scalatest.concurrent.Eventually.eventually$(Eventually.scala:312) ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? Passing GA ### Was this patch authored or co-authored using generative AI tooling? no Closes #46505 from yaooqinn/SPARK-48216. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala | 4 ---- .../org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala | 3 --- .../org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala | 4 ---- .../org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala | 3 --- .../spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 4 ---- .../org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala | 4 ---- .../org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala | 4 ---- 7 files changed, 26 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index aca174cce1949..4ece4d2088f4b 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -21,8 +21,6 @@ import java.math.BigDecimal import java.sql.{Connection, Date, Timestamp} import java.util.Properties -import org.scalatest.time.SpanSugar._ - import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.internal.SQLConf @@ -41,8 +39,6 @@ import org.apache.spark.tags.DockerTest class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DB2DatabaseOnDocker - override val connectionTimeout = timeout(3.minutes) - override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y VARCHAR(8))").executeUpdate() conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala index abb683c064955..4899de2b2a14c 100644 --- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala @@ -24,7 +24,6 @@ import javax.security.auth.login.Configuration import com.github.dockerjava.api.model.{AccessMode, Bind, ContainerConfig, HostConfig, Volume} import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.KERBEROS -import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.execution.datasources.jdbc.connection.{DB2ConnectionProvider, SecureConnectionProvider} @@ -68,8 +67,6 @@ class DB2KrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { } } - override val connectionTimeout = timeout(3.minutes) - override protected def setAuthentication(keytabFile: String, principal: String): Unit = { val config = new SecureConnectionProvider.JDBCConfiguration( Configuration.getConfiguration, "JaasClient", keytabFile, principal, true) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 496498e5455b4..1eee65986fccd 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -22,8 +22,6 @@ import java.sql.{Connection, Date, Timestamp} import java.time.{Duration, Period} import java.util.{Properties, TimeZone} -import org.scalatest.time.SpanSugar._ - import org.apache.spark.sql.{DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ @@ -68,8 +66,6 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark override val db = new OracleDatabaseOnDocker - override val connectionTimeout = timeout(7.minutes) - private val rsOfTsWithTimezone = Seq( Row(BigDecimal.valueOf(1), new Timestamp(944046000000L)), Row(BigDecimal.valueOf(2), new Timestamp(944078400000L)) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index 6c1b7fdd1be5a..3642094d11b29 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection import java.util.Locale -import org.scalatest.time.SpanSugar._ - import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -52,7 +50,6 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "db2" override val namespaceOpt: Option[String] = Some("DB2INST1") override val db = new DB2DatabaseOnDocker - override val connectionTimeout = timeout(3.minutes) override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName) diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 65f7579de8205..b1b8aec5ad337 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection -import org.scalatest.time.SpanSugar._ - import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -68,8 +66,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD .set("spark.sql.catalog.mssql.pushDownAggregate", "true") .set("spark.sql.catalog.mssql.pushDownLimit", "true") - override val connectionTimeout = timeout(7.minutes) - override def tablePreparation(connection: Connection): Unit = { connection.prepareStatement( "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)") diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 4997d335fda6b..22900c7bbcc8b 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.{Connection, SQLFeatureNotSupportedException} -import org.scalatest.time.SpanSugar._ - import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -68,8 +66,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest .set("spark.sql.catalog.mysql.pushDownLimit", "true") .set("spark.sql.catalog.mysql.pushDownOffset", "true") - override val connectionTimeout = timeout(7.minutes) - private var mySQLVersion = -1 override def tablePreparation(connection: Connection): Unit = { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index a011afac17720..b35018ec16dce 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection import java.util.Locale -import org.scalatest.time.SpanSugar._ - import org.apache.spark.{SparkConf, SparkRuntimeException} import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.util.CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY @@ -91,8 +89,6 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes .set("spark.sql.catalog.oracle.pushDownLimit", "true") .set("spark.sql.catalog.oracle.pushDownOffset", "true") - 
override val connectionTimeout = timeout(7.minutes) - override def tablePreparation(connection: Connection): Unit = { connection.prepareStatement( "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," + From b47d7853d92f733791513094af04fc18ec947246 Mon Sep 17 00:00:00 2001 From: Eric Maynard Date: Fri, 10 May 2024 08:41:24 +0900 Subject: [PATCH 32/65] [SPARK-48148][CORE] JSON objects should not be modified when read as STRING ### What changes were proposed in this pull request? Currently, when reading a JSON like this: ``` {"a": {"b": -999.99999999999999999999999999999999995}} ``` With the schema: ``` a STRING ``` Spark will yield a result like this: ``` {"b": -1000.0} ``` Other changes such as changes to the input string's whitespace may also occur. In some cases, we apply scientific notation to an input floating-point number when reading it as STRING. This applies to reading JSON files (as with `spark.read.json`) as well as the SQL expression `from_json`. ### Why are the changes needed? Correctness issues may occur if a field is read as a STRING and then later parsed (e.g. with `from_json`) after the contents have been modified. ### Does this PR introduce _any_ user-facing change? Yes, when reading non-string fields from a JSON object using the STRING type, we will now extract the field exactly as it appears. ### How was this patch tested? Added a test in `JsonSuite.scala` ### Was this patch authored or co-authored using generative AI tooling? No Closes #46408 from eric-maynard/SPARK-48148. Lead-authored-by: Eric Maynard Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/json/JacksonParser.scala | 63 ++++++++++++++++--- .../apache/spark/sql/internal/SQLConf.scala | 9 +++ .../datasources/json/JsonSuite.scala | 58 +++++++++++++++++ 3 files changed, 122 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 5e75ff6f6e1a3..b2c302fbbbe31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.fasterxml.jackson.core._ +import org.apache.hadoop.fs.PositionedReadable import org.apache.spark.SparkUpgradeException import org.apache.spark.internal.Logging @@ -275,19 +276,63 @@ class JacksonParser( } } - case _: StringType => - (parser: JsonParser) => parseJsonToken[UTF8String](parser, dataType) { + case _: StringType => (parser: JsonParser) => { + // This must be enabled if we will retrieve the bytes directly from the raw content: + val includeSourceInLocation = JsonParser.Feature.INCLUDE_SOURCE_IN_LOCATION + val originalMask = if (includeSourceInLocation.enabledIn(parser.getFeatureMask)) { + 1 + } else { + 0 + } + parser.overrideStdFeatures(includeSourceInLocation.getMask, includeSourceInLocation.getMask) + val result = parseJsonToken[UTF8String](parser, dataType) { case VALUE_STRING => UTF8String.fromString(parser.getText) - case _ => + case other => // Note that it always tries to convert the data as string without the case of failure. 
- val writer = new ByteArrayOutputStream() - Utils.tryWithResource(factory.createGenerator(writer, JsonEncoding.UTF8)) { - generator => generator.copyCurrentStructure(parser) + val startLocation = parser.currentTokenLocation() + def skipAhead(): Unit = { + other match { + case START_OBJECT => + parser.skipChildren() + case START_ARRAY => + parser.skipChildren() + case _ => + // Do nothing in this case; we've already read the token + } } - UTF8String.fromBytes(writer.toByteArray) - } + + // PositionedReadable + startLocation.contentReference().getRawContent match { + case byteArray: Array[Byte] if exactStringParsing => + skipAhead() + val endLocation = parser.currentLocation.getByteOffset + + UTF8String.fromBytes( + byteArray, + startLocation.getByteOffset.toInt, + endLocation.toInt - (startLocation.getByteOffset.toInt)) + case positionedReadable: PositionedReadable if exactStringParsing => + skipAhead() + val endLocation = parser.currentLocation.getByteOffset + + val size = endLocation.toInt - (startLocation.getByteOffset.toInt) + val buffer = new Array[Byte](size) + positionedReadable.read(startLocation.getByteOffset, buffer, 0, size) + UTF8String.fromBytes(buffer, 0, size) + case _ => + val writer = new ByteArrayOutputStream() + Utils.tryWithResource(factory.createGenerator(writer, JsonEncoding.UTF8)) { + generator => generator.copyCurrentStructure(parser) + } + UTF8String.fromBytes(writer.toByteArray) + } + } + // Reset back to the original configuration: + parser.overrideStdFeatures(includeSourceInLocation.getMask, originalMask) + result + } case TimestampType => (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { @@ -429,6 +474,8 @@ class JacksonParser( private val allowEmptyString = SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_EMPTY_STRING_IN_JSON) + private val exactStringParsing = SQLConf.get.getConf(SQLConf.JSON_EXACT_STRING_PARSING) + /** * This function throws an exception for failed conversion. For empty string on data types * except for string and binary types, this also throws an exception. 
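For context, the user-visible effect of the JacksonParser change above can be sketched with a small PySpark snippet. This sketch is illustrative and not part of the patch itself; it assumes an active `SparkSession` named `spark`, a writable scratch path, and the `spark.sql.json.enableExactStringParsing` flag added in the SQLConf change that follows:

```python
# Illustrative sketch for SPARK-48148 (assumes a running SparkSession `spark`
# and a hypothetical scratch path; not taken from the patch).
path = "/tmp/spark-48148-sample"
spark.createDataFrame(
    ['{"a": {"b": -999.99999999999999999999999999999999995}}'], "string"
).write.mode("overwrite").text(path)

# With exact string parsing enabled (the new default), the nested object should
# come back verbatim when the field is read with a STRING schema.
spark.conf.set("spark.sql.json.enableExactStringParsing", "true")
spark.read.schema("a STRING").json(path).show(truncate=False)

# With the flag disabled, the old behavior applies and the value may be
# re-serialized (e.g. the number rendered as -1000.0).
spark.conf.set("spark.sql.json.enableExactStringParsing", "false")
spark.read.schema("a STRING").json(path).show(truncate=False)
```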
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 54aa87260534f..e78157d611586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4257,6 +4257,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val JSON_EXACT_STRING_PARSING = + buildConf("spark.sql.json.enableExactStringParsing") + .internal() + .doc("When set to true, string columns extracted from JSON objects will be extracted " + + "exactly as they appear in the input string, with no changes") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK = buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback") .internal() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index c17a25be8e2ae..3d0eedd2f689c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3865,6 +3865,64 @@ abstract class JsonSuite } } } + + test("SPARK-48148: values are unchanged when read as string") { + withTempPath { path => + def extractData( + jsonString: String, + expectedInexactData: Seq[String], + expectedExactData: Seq[String], + multiLine: Boolean = false): Unit = { + Seq(jsonString).toDF() + .repartition(1) + .write + .mode("overwrite") + .text(path.getAbsolutePath) + + withClue("Exact string parsing") { + withSQLConf(SQLConf.JSON_EXACT_STRING_PARSING.key -> "true") { + val df = spark.read + .schema("data STRING") + .option("multiLine", multiLine.toString) + .json(path.getAbsolutePath) + checkAnswer(df, expectedExactData.map(d => Row(d))) + } + } + + withClue("Inexact string parsing") { + withSQLConf(SQLConf.JSON_EXACT_STRING_PARSING.key -> "false") { + val df = spark.read + .schema("data STRING") + .option("multiLine", multiLine.toString) + .json(path.getAbsolutePath) + checkAnswer(df, expectedInexactData.map(d => Row(d))) + } + } + } + extractData( + """{"data": {"white": "space"}}""", + expectedInexactData = Seq("""{"white":"space"}"""), + expectedExactData = Seq("""{"white": "space"}""") + ) + extractData( + """{"data": ["white", "space"]}""", + expectedInexactData = Seq("""["white","space"]"""), + expectedExactData = Seq("""["white", "space"]""") + ) + val granularFloat = "-999.99999999999999999999999999999999995" + extractData( + s"""{"data": {"v": ${granularFloat}}}""", + expectedInexactData = Seq("""{"v":-1000.0}"""), + expectedExactData = Seq(s"""{"v": ${granularFloat}}""") + ) + extractData( + s"""{"data": {"white":\n"space"}}""", + expectedInexactData = Seq("""{"white":"space"}"""), + expectedExactData = Seq(s"""{"white":\n"space"}"""), + multiLine = true + ) + } + } } class JsonV1Suite extends JsonSuite { From e704b9e56b0cc862ebd5c95b9d023ab0a5ffdba7 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 9 May 2024 16:42:58 -0700 Subject: [PATCH 33/65] [SPARK-48226][BUILD] Add `spark-ganglia-lgpl` to `lint-java` & `spark-ganglia-lgpl` and `jvm-profiler` to `sbt-checkstyle` ### What changes were proposed in this pull request? 
The pr aims to add - `spark-ganglia-lgpl` to `lint-java` - `spark-ganglia-lgpl` and `jvm-profiler` to `sbt-checkstyle` ### Why are the changes needed? 1.Because the module `spark-ganglia-lgpl` has `java` code 2.Because the module `spark-ganglia-lgpl` & `jvm-profiler` has `scala` code 3.Although these module codes currently comply with the specification, in order to avoid problems like https://github.com/apache/spark/pull/46376, they will occur again in future modifications. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Manually test. - Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46501 from panbingkun/minor_spark-ganglia-lgpl. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- dev/lint-java | 2 +- dev/sbt-checkstyle | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/lint-java b/dev/lint-java index ac5a2c869404f..ff431301773f3 100755 --- a/dev/lint-java +++ b/dev/lint-java @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" -ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pkubernetes -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pspark-ganglia-lgpl -Pkubernetes -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) if test ! -z "$ERRORS"; then echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" diff --git a/dev/sbt-checkstyle b/dev/sbt-checkstyle index 99a46a3a0e38b..f2d5a0fa304ac 100755 --- a/dev/sbt-checkstyle +++ b/dev/sbt-checkstyle @@ -17,7 +17,7 @@ # limitations under the License. # -SPARK_PROFILES=${1:-"-Pkinesis-asl -Pkubernetes -Pyarn -Phive -Phive-thriftserver"} +SPARK_PROFILES=${1:-"-Pkinesis-asl -Pspark-ganglia-lgpl -Pkubernetes -Pyarn -Phive -Phive-thriftserver -Pjvm-profiler"} # NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file # with failure (either resolution or compilation); the "q" makes SBT quit. From 71f0eda71bc169a5245f4412ec0957728025a66c Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 10 May 2024 08:44:07 +0900 Subject: [PATCH 34/65] [SPARK-48180][SQL] Improve error when UDTF call with TABLE arg forgets parentheses around multiple PARTITION/ORDER BY exprs ### What changes were proposed in this pull request? This PR improves the error message when a table-valued function call has a TABLE argument with a PARTITION BY or ORDER BY clause with more than one associated expression, but forgets parentheses around them. 
For example: ``` SELECT * FROM testUDTF( TABLE(SELECT 1 AS device_id, 2 AS data_ds) WITH SINGLE PARTITION ORDER BY device_id, data_ds) ``` This query previously returned an obscure, unrelated error: ``` [UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT] Unsupported subquery expression: Table arguments are used in a function where they are not supported: 'UnresolvedTableValuedFunction [tvf], [table-argument#338 [], 'data_ds](https://issues.apache.org/jira/browse/SPARK-48180#338%20[],%20'data_ds), false +- Project [1 AS device_id#336, 2 AS data_ds#337](https://issues.apache.org/jira/browse/SPARK-48180#336,%202%20AS%20data_ds#337) +- OneRowRelation ``` Now it returns a reasonable error: ``` The table function call includes a table argument with an invalid partitioning/ordering specification: the ORDER BY clause included multiple expressions without parentheses surrounding them; please add parentheses around these expressions and then retry the query again. (line 4, pos 2) == SQL == SELECT * FROM testUDTF( TABLE(SELECT 1 AS device_id, 2 AS data_ds) WITH SINGLE PARTITION --^^^ ORDER BY device_id, data_ds) ``` ### Why are the changes needed? Here we improve error messages for common SQL syntax mistakes. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds test coverage. ### Was this patch authored or co-authored using generative AI tooling? No Closes #46451 from dtenedor/udtf-analyzer-bug. Authored-by: Daniel Tenedorio Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/parser/SqlBaseParser.g4 | 2 ++ .../sql/catalyst/parser/AstBuilder.scala | 14 ++++++++++ .../sql/catalyst/parser/PlanParserSuite.scala | 19 ++++++++++---- .../execution/python/PythonUDTFSuite.scala | 26 +++++++++++++++++++ 4 files changed, 56 insertions(+), 5 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 653224c5475f8..249f55fa40ac5 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -838,9 +838,11 @@ tableArgumentPartitioning : ((WITH SINGLE PARTITION) | ((PARTITION | DISTRIBUTE) BY (((LEFT_PAREN partition+=expression (COMMA partition+=expression)* RIGHT_PAREN)) + | (expression (COMMA invalidMultiPartitionExpression=expression)+) | partition+=expression))) ((ORDER | SORT) BY (((LEFT_PAREN sortItem (COMMA sortItem)* RIGHT_PAREN) + | (sortItem (COMMA invalidMultiSortItem=sortItem)+) | sortItem)))? 
; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7d2355b2f08d1..326f1e7684b9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1638,6 +1638,20 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } partitionByExpressions = p.partition.asScala.map(expression).toSeq orderByExpressions = p.sortItem.asScala.map(visitSortItem).toSeq + def invalidPartitionOrOrderingExpression(clause: String): String = { + "The table function call includes a table argument with an invalid " + + s"partitioning/ordering specification: the $clause clause included multiple " + + "expressions without parentheses surrounding them; please add parentheses around " + + "these expressions and then retry the query again" + } + validate( + Option(p.invalidMultiPartitionExpression).isEmpty, + message = invalidPartitionOrOrderingExpression("PARTITION BY"), + ctx = p.invalidMultiPartitionExpression) + validate( + Option(p.invalidMultiSortItem).isEmpty, + message = invalidPartitionOrOrderingExpression("ORDER BY"), + ctx = p.invalidMultiSortItem) } validate( !(withSinglePartition && partitionByExpressions.nonEmpty), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 17dd7349e7bea..8d01040563361 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -1617,14 +1617,23 @@ class PlanParserSuite extends AnalysisTest { parameters = Map( "error" -> "'order'", "hint" -> "")) - val sql8 = s"select * from my_tvf(arg1 => table(select col1, col2, col3 from v2) " + - s"$partition by col1, col2 order by col2 asc, col3 desc)" + val sql8tableArg = "table(select col1, col2, col3 from v2)" + val sql8partition = s"$partition by col1, col2 order by col2 asc, col3 desc" + val sql8 = s"select * from my_tvf(arg1 => $sql8tableArg $sql8partition)" checkError( exception = parseException(sql8), - errorClass = "PARSE_SYNTAX_ERROR", + errorClass = "_LEGACY_ERROR_TEMP_0064", parameters = Map( - "error" -> "'order'", - "hint" -> ": extra input 'order'")) + "msg" -> + ("The table function call includes a table argument with an invalid " + + "partitioning/ordering specification: the PARTITION BY clause included multiple " + + "expressions without parentheses surrounding them; please add parentheses around " + + "these expressions and then retry the query again")), + context = ExpectedContext( + fragment = s"$sql8tableArg $sql8partition", + start = 29, + stop = 110 + partition.length) + ) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala index 989597ae041db..1eaf1d24056da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, 
Row} import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Expression, FunctionTableSubqueryArgumentExpression, Literal} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, Sort, SubqueryAlias} import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf @@ -363,4 +364,29 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { Row("abc")) } } + + test("SPARK-48180: Analyzer bug with multiple ORDER BY items for input table argument") { + assume(shouldTestPythonUDFs) + spark.udtf.registerPython("testUDTF", pythonUDTF) + checkError( + exception = intercept[ParseException](sql( + """ + |SELECT * FROM testUDTF( + | TABLE(SELECT 1 AS device_id, 2 AS data_ds) + | WITH SINGLE PARTITION + | ORDER BY device_id, data_ds) + |""".stripMargin)), + errorClass = "_LEGACY_ERROR_TEMP_0064", + parameters = Map("msg" -> + ("The table function call includes a table argument with an invalid " + + "partitioning/ordering specification: the ORDER BY clause included multiple " + + "expressions without parentheses surrounding them; please add parentheses around these " + + "expressions and then retry the query again")), + context = ExpectedContext( + fragment = "TABLE(SELECT 1 AS device_id, 2 AS data_ds)\n " + + "WITH SINGLE PARTITION\n " + + "ORDER BY device_id, data_ds", + start = 27, + stop = 122)) + } } From 012d19d8e9b28f7ce266753bcfff4a76c9510245 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 9 May 2024 16:58:44 -0700 Subject: [PATCH 35/65] [SPARK-48227][PYTHON][DOC] Document the requirement of seed in protos ### What changes were proposed in this pull request? Document the requirement of seed in protos ### Why are the changes needed? the seed should be set at client side document it to avoid cases like https://github.com/apache/spark/pull/46456 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #46518 from zhengruifeng/doc_random. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- .../src/main/protobuf/spark/connect/relations.proto | 8 ++++++-- python/pyspark/sql/connect/plan.py | 10 ++++------ python/pyspark/sql/connect/proto/relations_pb2.pyi | 10 ++++++++-- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index 3882b2e853967..0b3c9d4253e8c 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -467,7 +467,9 @@ message Sample { // (Optional) Whether to sample with replacement. optional bool with_replacement = 4; - // (Optional) The random seed. + // (Required) The random seed. + // This filed is required to avoid generate mutable dataframes (see SPARK-48184 for details), + // however, still keep it 'optional' here for backward compatibility. optional int64 seed = 5; // (Required) Explicitly sort the underlying plan to make the ordering deterministic or cache it. @@ -687,7 +689,9 @@ message StatSampleBy { // If a stratum is not specified, we treat its fraction as zero. repeated Fraction fractions = 3; - // (Optional) The random seed. + // (Required) The random seed. 
+ // This filed is required to avoid generate mutable dataframes (see SPARK-48184 for details), + // however, still keep it 'optional' here for backward compatibility. optional int64 seed = 5; message Fraction { diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 4ac4946745f5e..3d3303fb15c57 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -717,7 +717,7 @@ def __init__( lower_bound: float, upper_bound: float, with_replacement: bool, - seed: Optional[int], + seed: int, deterministic_order: bool = False, ) -> None: super().__init__(child) @@ -734,8 +734,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: plan.sample.lower_bound = self.lower_bound plan.sample.upper_bound = self.upper_bound plan.sample.with_replacement = self.with_replacement - if self.seed is not None: - plan.sample.seed = self.seed + plan.sample.seed = self.seed plan.sample.deterministic_order = self.deterministic_order return plan @@ -1526,7 +1525,7 @@ def __init__( child: Optional["LogicalPlan"], col: Column, fractions: Sequence[Tuple[Column, float]], - seed: Optional[int], + seed: int, ) -> None: super().__init__(child) @@ -1554,8 +1553,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: fraction.stratum.CopyFrom(k.to_plan(session).literal) fraction.fraction = float(v) plan.sample_by.fractions.append(fraction) - if self._seed is not None: - plan.sample_by.seed = self._seed + plan.sample_by.seed = self._seed return plan diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 5dfb47da67a97..9b6f4b43544f2 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -1865,7 +1865,10 @@ class Sample(google.protobuf.message.Message): with_replacement: builtins.bool """(Optional) Whether to sample with replacement.""" seed: builtins.int - """(Optional) The random seed.""" + """(Required) The random seed. + This filed is required to avoid generate mutable dataframes (see SPARK-48184 for details), + however, still keep it 'optional' here for backward compatibility. + """ deterministic_order: builtins.bool """(Required) Explicitly sort the underlying plan to make the ordering deterministic or cache it. This flag is true when invoking `dataframe.randomSplit` to randomly splits DataFrame with the @@ -2545,7 +2548,10 @@ class StatSampleBy(google.protobuf.message.Message): If a stratum is not specified, we treat its fraction as zero. """ seed: builtins.int - """(Optional) The random seed.""" + """(Required) The random seed. + This filed is required to avoid generate mutable dataframes (see SPARK-48184 for details), + however, still keep it 'optional' here for backward compatibility. + """ def __init__( self, *, From 9a2818820f11f9bdcc042f4ab80850918911c68c Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Fri, 10 May 2024 09:58:16 +0800 Subject: [PATCH 36/65] [SPARK-48222][INFRA][DOCS] Sync Ruby Bundler to 2.4.22 and refresh Gem lock file ### What changes were proposed in this pull request? Sync the version of Bundler that we are using across various scripts and documentation. Also refresh the Gem lock file. ### Why are the changes needed? We are seeing inconsistent build behavior, likely due to the inconsistent Bundler versions. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI + the preview release process. 
### Was this patch authored or co-authored using generative AI tooling? No. Closes #46512 from nchammas/bundler-sync. Authored-by: Nicholas Chammas Signed-off-by: Wenchen Fan --- .github/workflows/build_and_test.yml | 3 +++ dev/create-release/spark-rm/Dockerfile | 2 +- docs/Gemfile.lock | 16 ++++++++-------- docs/README.md | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 4a11823aee604..881fb8cb06745 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -872,6 +872,9 @@ jobs: python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421 - name: Install dependencies for documentation generation run: | + # Keep the version of Bundler here in sync with the following locations: + # - dev/create-release/spark-rm/Dockerfile + # - docs/README.md gem install bundler -v 2.4.22 cd docs bundle install diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index 8d5ca38ba88ee..13f4112ca03da 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -38,7 +38,7 @@ ENV DEBCONF_NONINTERACTIVE_SEEN true ARG APT_INSTALL="apt-get install --no-install-recommends -y" ARG PIP_PKGS="sphinx==4.5.0 mkdocs==1.1.2 numpy==1.20.3 pydata_sphinx_theme==0.13.3 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==3.1.2 twine==3.4.1 sphinx-plotly-directive==0.1.3 sphinx-copybutton==0.5.2 pandas==2.0.3 pyarrow==10.0.1 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.62.0 protobuf==4.21.6 grpcio-status==1.62.0 googleapis-common-protos==1.56.4" -ARG GEM_PKGS="bundler:2.3.8" +ARG GEM_PKGS="bundler:2.4.22" # Install extra needed repos and refresh. # - CRAN repo diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 4e38f18703f3c..e137f0f039b97 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -4,16 +4,16 @@ GEM addressable (2.8.6) public_suffix (>= 2.0.2, < 6.0) colorator (1.1.0) - concurrent-ruby (1.2.2) + concurrent-ruby (1.2.3) em-websocket (0.5.3) eventmachine (>= 0.12.9) http_parser.rb (~> 0) eventmachine (1.2.7) ffi (1.16.3) forwardable-extended (2.6.0) - google-protobuf (3.25.2) + google-protobuf (3.25.3) http_parser.rb (0.8.0) - i18n (1.14.1) + i18n (1.14.5) concurrent-ruby (~> 1.0) jekyll (4.3.3) addressable (~> 2.4) @@ -42,22 +42,22 @@ GEM kramdown-parser-gfm (1.1.0) kramdown (~> 2.0) liquid (4.0.4) - listen (3.8.0) + listen (3.9.0) rb-fsevent (~> 0.10, >= 0.10.3) rb-inotify (~> 0.9, >= 0.9.10) mercenary (0.4.0) pathutil (0.16.2) forwardable-extended (~> 2.6) - public_suffix (5.0.4) - rake (13.1.0) + public_suffix (5.0.5) + rake (13.2.1) rb-fsevent (0.11.2) rb-inotify (0.10.1) ffi (~> 1.0) rexml (3.2.6) rouge (3.30.0) safe_yaml (1.0.5) - sass-embedded (1.69.7) - google-protobuf (~> 3.25) + sass-embedded (1.63.6) + google-protobuf (~> 3.23) rake (>= 13.0.0) terminal-table (3.0.2) unicode-display_width (>= 1.1.1, < 3) diff --git a/docs/README.md b/docs/README.md index 414c8dbd83035..363f1c2076363 100644 --- a/docs/README.md +++ b/docs/README.md @@ -36,7 +36,7 @@ You need to have [Ruby 3][ruby] and [Python 3][python] installed. 
Make sure the [python]: https://www.python.org/downloads/ ```sh -$ gem install bundler +$ gem install bundler -v 2.4.22 ``` After this all the required Ruby dependencies can be installed from the `docs/` directory via Bundler: From a41d0ae79b432e2757379fc56a0ad2755f02e871 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Fri, 10 May 2024 12:23:34 +0900 Subject: [PATCH 37/65] [SPARK-48176][SQL] Adjust name of FIELD_ALREADY_EXISTS error condition ### What changes were proposed in this pull request? Rename `FIELDS_ALREADY_EXISTS` to `FIELD_ALREADY_EXISTS`. ### Why are the changes needed? Though it's not meant to be a proper English sentence, `FIELDS_ALREADY_EXISTS` is grammatically incorrect. It should either be "fields already exist[]" or "field[] already exists". I opted for the latter. ### Does this PR introduce _any_ user-facing change? Yes, it changes the name of an error condition. ### How was this patch tested? CI only. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46510 from nchammas/SPARK-48176-field-exists-error. Authored-by: Nicholas Chammas Signed-off-by: Hyukjin Kwon --- common/utils/src/main/resources/error/error-conditions.json | 2 +- .../test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 4 ++-- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../org/apache/spark/sql/connector/AlterTableTests.scala | 4 ++-- .../spark/sql/connector/V2CommandsCaseSensitivitySuite.scala | 4 ++-- .../execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8a64c4c590e8a..7c9886c749b95 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1339,7 +1339,7 @@ ], "sqlState" : "54001" }, - "FIELDS_ALREADY_EXISTS" : { + "FIELD_ALREADY_EXISTS" : { "message" : [ "Cannot column, because already exists in ." 
], diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index c80fbfc748dd1..b60107f902839 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -107,7 +107,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C3 DOUBLE)") }, - errorClass = "FIELDS_ALREADY_EXISTS", + errorClass = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "add", "fieldNames" -> "`C3`", @@ -179,7 +179,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.alt_table RENAME COLUMN ID1 TO ID2") }, - errorClass = "FIELDS_ALREADY_EXISTS", + errorClass = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "rename", "fieldNames" -> "`ID2`", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index e55f23b6aa86e..e18f4d1b36e1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1403,7 +1403,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB if (struct.findNestedField( fieldNames, includeCollections = true, alter.conf.resolver).isDefined) { alter.failAnalysis( - errorClass = "FIELDS_ALREADY_EXISTS", + errorClass = "FIELD_ALREADY_EXISTS", messageParameters = Map( "op" -> op, "fieldNames" -> toSQLId(fieldNames), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 996d7acb1148d..28605958c71da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -466,7 +466,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t ADD COLUMNS $field double") }, - errorClass = "FIELDS_ALREADY_EXISTS", + errorClass = "FIELD_ALREADY_EXISTS", parameters = expectedParameters, context = ExpectedContext( fragment = s"ALTER TABLE $t ADD COLUMNS $field double", @@ -1116,7 +1116,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t RENAME COLUMN $field TO $newName") }, - errorClass = "FIELDS_ALREADY_EXISTS", + errorClass = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "rename", "fieldNames" -> s"${toSQLId(expectedName)}", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index ee71bd3af1e02..3ab7edb78439c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -306,7 +306,7 @@ class 
V2CommandsCaseSensitivitySuite None, Some(UnresolvedFieldPosition(ColumnPosition.after("id"))), None))), - "FIELDS_ALREADY_EXISTS", + "FIELD_ALREADY_EXISTS", Map( "op" -> "add", "fieldNames" -> "`ID`", @@ -317,7 +317,7 @@ class V2CommandsCaseSensitivitySuite test("SPARK-36381: Check column name exist case sensitive and insensitive when rename column") { alterTableErrorClass( RenameColumn(table, UnresolvedFieldName(Array("id").toImmutableArraySeq), "DATA"), - "FIELDS_ALREADY_EXISTS", + "FIELD_ALREADY_EXISTS", Map( "op" -> "rename", "fieldNames" -> "`DATA`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala index f4e7921e88bc2..daf5d8507ecc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala @@ -200,7 +200,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $tableName ADD COLUMNS (c3 DOUBLE)") }, - errorClass = "FIELDS_ALREADY_EXISTS", + errorClass = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "add", "fieldNames" -> "`c3`", @@ -239,7 +239,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $tableName RENAME COLUMN C TO C0") }, - errorClass = "FIELDS_ALREADY_EXISTS", + errorClass = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "rename", "fieldNames" -> "`C0`", From 32b2827b964bd4a4accb60b47ddd6929f41d4a89 Mon Sep 17 00:00:00 2001 From: YangJie Date: Thu, 9 May 2024 20:47:34 -0700 Subject: [PATCH 38/65] [SPARK-47834][SQL][CONNECT] Mark deprecated functions with `@deprecated` in `SQLImplicits` ### What changes were proposed in this pull request? In the `sql` module, some functions in `SQLImplicits` have already been marked as `deprecated` in the function comments after SPARK-19089. This pr adds `deprecated` type annotation marks to them. Since SPARK-19089 occurred in Spark 2.2.0, the `since` field of `deprecated` is filled in as `2.2.0`. At the same time, these `deprecated` marks have also been synchronized to the corresponding functions in `SQLImplicits` in the `connect` module. ### Why are the changes needed? Mark deprecated functions with `deprecated` in `SQLImplicits` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass Github Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #46029 from LuciferYang/deprecated-SQLImplicits. 
Lead-authored-by: YangJie Co-authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/sql/SQLImplicits.scala | 9 +++++++++ .../main/scala/org/apache/spark/sql/SQLImplicits.scala | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 6c626fd716d5b..7799d395d5c6a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -149,6 +149,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") val newIntSeqEncoder: Encoder[Seq[Int]] = newSeqEncoder(PrimitiveIntEncoder) /** @@ -156,6 +157,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") val newLongSeqEncoder: Encoder[Seq[Long]] = newSeqEncoder(PrimitiveLongEncoder) /** @@ -163,6 +165,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") val newDoubleSeqEncoder: Encoder[Seq[Double]] = newSeqEncoder(PrimitiveDoubleEncoder) /** @@ -170,6 +173,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") val newFloatSeqEncoder: Encoder[Seq[Float]] = newSeqEncoder(PrimitiveFloatEncoder) /** @@ -177,6 +181,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") val newByteSeqEncoder: Encoder[Seq[Byte]] = newSeqEncoder(PrimitiveByteEncoder) /** @@ -184,6 +189,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") val newShortSeqEncoder: Encoder[Seq[Short]] = newSeqEncoder(PrimitiveShortEncoder) /** @@ -191,6 +197,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") val newBooleanSeqEncoder: Encoder[Seq[Boolean]] = newSeqEncoder(PrimitiveBooleanEncoder) /** @@ -198,6 +205,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") val newStringSeqEncoder: Encoder[Seq[String]] = newSeqEncoder(StringEncoder) /** @@ -205,6 +213,7 @@ abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrio * @deprecated * use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newProductSeqEncoder[A <: Product: TypeTag]: Encoder[Seq[A]] = newSeqEncoder(ScalaReflection.encoderFor[A]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index d257a6b771b93..56f13994277d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala 
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -130,54 +130,63 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() /** * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() /** * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() /** * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() /** * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() /** * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() /** * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() /** * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() /** * @since 1.6.1 * @deprecated use [[newSequenceEncoder]] */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() /** @since 2.2.0 */ From 1138b2a68b5408e6d079bdbce8026323694628e5 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Thu, 9 May 2024 20:51:32 -0700 Subject: [PATCH 39/65] [MINOR][BUILD] Remove duplicate configuration of maven-compiler-plugin ### What changes were proposed in this pull request? `${java.version}` and `${java.version}` (https://github.com/apache/spark/pull/46024/files#diff-9c5fb3d1b7e3b0f54bc5c4182965c4fe1f9023d449017cece3005d3f90e8e4d8R117) are equivalent duplicate configuration, so remove `${java.version}`. https://maven.apache.org/plugins/maven-compiler-plugin/examples/set-compiler-release.html ### Why are the changes needed? Simplify the code and facilitates subsequent configuration iterations. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46024 from zml1206/remove_duplicate_configuration. Authored-by: zml1206 Signed-off-by: Dongjoon Hyun --- pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/pom.xml b/pom.xml index c3ff5d101c224..678455e6e2482 100644 --- a/pom.xml +++ b/pom.xml @@ -3127,7 +3127,6 @@ maven-compiler-plugin 3.13.0 - ${java.version} true true From 2d609bfd37ae9a0877fb72d1ba0479bb04a2dad6 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Thu, 9 May 2024 21:31:50 -0700 Subject: [PATCH 40/65] [SPARK-47018][BUILD][SQL] Bump built-in Hive to 2.3.10 ### What changes were proposed in this pull request? This PR aims to bump Spark's built-in Hive from 2.3.9 to Hive 2.3.10, with two additional changes: - due to API breaking changes of Thrift, `libthrift` is upgraded from `0.12` to `0.16`. 
- remove version management of `commons-lang:2.6`, it comes from Hive transitive deps, Hive 2.3.10 drops it in https://github.com/apache/hive/pull/4892 This is the first part of https://github.com/apache/spark/pull/45372 ### Why are the changes needed? Bump Hive to the latest version of 2.3, prepare for upgrading Guava, and dropping vulnerable dependencies like Jackson 1.x / Jodd ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. (wait for sunchao to complete the 2.3.10 release to make jars visible on Maven Central) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45372 Closes #46468 from pan3793/SPARK-47018. Lead-authored-by: Cheng Pan Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- connector/kafka-0-10-assembly/pom.xml | 5 --- connector/kinesis-asl-assembly/pom.xml | 5 --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 27 ++++++++-------- docs/building-spark.md | 4 +-- docs/sql-data-sources-hive-tables.md | 8 ++--- docs/sql-migration-guide.md | 2 +- pom.xml | 31 ++++++++----------- .../hive/service/auth/KerberosSaslHelper.java | 5 +-- .../hive/service/auth/PlainSaslHelper.java | 3 +- .../service/auth/TSetIpAddressProcessor.java | 5 +-- .../cli/thrift/ThriftBinaryCLIService.java | 6 ---- .../service/cli/thrift/ThriftCLIService.java | 10 ++++++ .../org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../spark/sql/hive/client/package.scala | 5 ++- .../HiveExternalCatalogVersionsSuite.scala | 1 - .../spark/sql/hive/HiveSparkSubmitSuite.scala | 10 +++--- .../sql/hive/execution/HiveQuerySuite.scala | 6 ++-- 17 files changed, 61 insertions(+), 74 deletions(-) diff --git a/connector/kafka-0-10-assembly/pom.xml b/connector/kafka-0-10-assembly/pom.xml index b2fcbdf8eca7d..bd311b3a98047 100644 --- a/connector/kafka-0-10-assembly/pom.xml +++ b/connector/kafka-0-10-assembly/pom.xml @@ -54,11 +54,6 @@ commons-codec provided - - commons-lang - commons-lang - provided - com.google.protobuf protobuf-java diff --git a/connector/kinesis-asl-assembly/pom.xml b/connector/kinesis-asl-assembly/pom.xml index 577ec21530837..0e93526fce721 100644 --- a/connector/kinesis-asl-assembly/pom.xml +++ b/connector/kinesis-asl-assembly/pom.xml @@ -54,11 +54,6 @@ jackson-databind provided - - commons-lang - commons-lang - provided - org.glassfish.jersey.core jersey-client diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 73d41e9eeb337..392bacd73277f 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -46,7 +46,6 @@ commons-compress/1.26.1//commons-compress-1.26.1.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar commons-io/2.16.1//commons-io-2.16.1.jar -commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.14.0//commons-lang3-3.14.0.jar commons-math3/3.6.1//commons-math3-3.6.1.jar commons-pool/1.5.4//commons-pool-1.5.4.jar @@ -81,19 +80,19 @@ hadoop-cloud-storage/3.4.0//hadoop-cloud-storage-3.4.0.jar hadoop-huaweicloud/3.4.0//hadoop-huaweicloud-3.4.0.jar hadoop-shaded-guava/1.2.0//hadoop-shaded-guava-1.2.0.jar hadoop-yarn-server-web-proxy/3.4.0//hadoop-yarn-server-web-proxy-3.4.0.jar -hive-beeline/2.3.9//hive-beeline-2.3.9.jar -hive-cli/2.3.9//hive-cli-2.3.9.jar -hive-common/2.3.9//hive-common-2.3.9.jar -hive-exec/2.3.9/core/hive-exec-2.3.9-core.jar -hive-jdbc/2.3.9//hive-jdbc-2.3.9.jar -hive-llap-common/2.3.9//hive-llap-common-2.3.9.jar -hive-metastore/2.3.9//hive-metastore-2.3.9.jar 
-hive-serde/2.3.9//hive-serde-2.3.9.jar +hive-beeline/2.3.10//hive-beeline-2.3.10.jar +hive-cli/2.3.10//hive-cli-2.3.10.jar +hive-common/2.3.10//hive-common-2.3.10.jar +hive-exec/2.3.10/core/hive-exec-2.3.10-core.jar +hive-jdbc/2.3.10//hive-jdbc-2.3.10.jar +hive-llap-common/2.3.10//hive-llap-common-2.3.10.jar +hive-metastore/2.3.10//hive-metastore-2.3.10.jar +hive-serde/2.3.10//hive-serde-2.3.10.jar hive-service-rpc/4.0.0//hive-service-rpc-4.0.0.jar -hive-shims-0.23/2.3.9//hive-shims-0.23-2.3.9.jar -hive-shims-common/2.3.9//hive-shims-common-2.3.9.jar -hive-shims-scheduler/2.3.9//hive-shims-scheduler-2.3.9.jar -hive-shims/2.3.9//hive-shims-2.3.9.jar +hive-shims-0.23/2.3.10//hive-shims-0.23-2.3.10.jar +hive-shims-common/2.3.10//hive-shims-common-2.3.10.jar +hive-shims-scheduler/2.3.10//hive-shims-scheduler-2.3.10.jar +hive-shims/2.3.10//hive-shims-2.3.10.jar hive-storage-api/2.8.1//hive-storage-api-2.8.1.jar hk2-api/3.0.3//hk2-api-3.0.3.jar hk2-locator/3.0.3//hk2-locator-3.0.3.jar @@ -184,7 +183,7 @@ kubernetes-model-storageclass/6.12.1//kubernetes-model-storageclass-6.12.1.jar lapack/3.0.3//lapack-3.0.3.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar -libthrift/0.12.0//libthrift-0.12.0.jar +libthrift/0.16.0//libthrift-0.16.0.jar log4j-1.2-api/2.22.1//log4j-1.2-api-2.22.1.jar log4j-api/2.22.1//log4j-api-2.22.1.jar log4j-core/2.22.1//log4j-core-2.22.1.jar diff --git a/docs/building-spark.md b/docs/building-spark.md index 73fc31610d95d..8b04ac9b4a34f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -85,9 +85,9 @@ Example: To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` and `-Phive-thriftserver` profiles to your existing build options. -By default Spark will build with Hive 2.3.9. +By default Spark will build with Hive 2.3.10. - # With Hive 2.3.9 support + # With Hive 2.3.10 support ./build/mvn -Pyarn -Phive -Phive-thriftserver -DskipTests clean package ## Packaging without Hadoop Dependencies for YARN diff --git a/docs/sql-data-sources-hive-tables.md b/docs/sql-data-sources-hive-tables.md index b51cde53bd8fd..566dcb33a25d9 100644 --- a/docs/sql-data-sources-hive-tables.md +++ b/docs/sql-data-sources-hive-tables.md @@ -127,10 +127,10 @@ The following options can be used to configure the version of Hive that is used Property NameDefaultMeaningSince Version spark.sql.hive.metastore.version - 2.3.9 + 2.3.10 Version of the Hive metastore. Available - options are 2.0.0 through 2.3.9 and 3.0.0 through 3.1.3. + options are 2.0.0 through 2.3.10 and 3.0.0 through 3.1.3. 1.4.0 @@ -142,9 +142,9 @@ The following options can be used to configure the version of Hive that is used property can be one of four options:

  1. builtin
  2. - Use Hive 2.3.9, which is bundled with the Spark assembly when -Phive is + Use Hive 2.3.10, which is bundled with the Spark assembly when -Phive is enabled. When this option is chosen, spark.sql.hive.metastore.version must be - either 2.3.9 or not defined. + either 2.3.10 or not defined.
  3. maven
  4. Use Hive jars of specified version downloaded from Maven repositories. This configuration is not generally recommended for production deployments. diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 8b55fb48b8b57..d95d2893f6d79 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -1068,7 +1068,7 @@ Python UDF registration is unchanged. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on built-in Hive, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.9 and 3.0.0 to 3.1.3. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)). +(from 2.0.0 to 2.3.10 and 3.0.0 to 3.1.3. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses {:.no_toc} diff --git a/pom.xml b/pom.xml index 678455e6e2482..56a34cedde516 100644 --- a/pom.xml +++ b/pom.xml @@ -132,8 +132,8 @@ org.apache.hive core - 2.3.9 - 2.3.9 + 2.3.10 + 2.3.10 2.3 @@ -192,8 +192,6 @@ 1.17.0 1.26.1 2.16.1 - - 2.6 3.14.0 @@ -206,7 +204,7 @@ 3.5.2 3.0.0 2.2.11 - 0.12.0 + 0.16.0 4.13.1 1.1 4.17.0 @@ -615,11 +613,6 @@ commons-text 1.12.0 - - commons-lang - commons-lang - ${commons-lang2.version} - commons-io commons-io @@ -2294,8 +2287,8 @@ janino - org.pentaho - pentaho-aggdesigner-algorithm + net.hydromatic + aggdesigner-algorithm @@ -2365,6 +2358,10 @@ org.codehaus.groovy groovy-all + + com.lmax + disruptor + @@ -2805,6 +2802,10 @@ org.slf4j slf4j-api + + javax.annotation + javax.annotation-api + @@ -2898,12 +2899,6 @@ hive-storage-api ${hive.storage.version} ${hive.storage.scope} - - - commons-lang - commons-lang - - commons-cli diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/KerberosSaslHelper.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/KerberosSaslHelper.java index 175412ed98c6c..ef91f94eeec2b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/KerberosSaslHelper.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/KerberosSaslHelper.java @@ -30,6 +30,7 @@ import org.apache.thrift.TProcessorFactory; import org.apache.thrift.transport.TSaslClientTransport; import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; public final class KerberosSaslHelper { @@ -68,8 +69,8 @@ public static TTransport createSubjectAssumedTransport(String principal, new TSaslClientTransport("GSSAPI", null, names[0], names[1], saslProps, null, underlyingTransport); return new TSubjectAssumingTransport(saslTransport); - } catch (SaslException se) { - throw new IOException("Could not instantiate SASL transport", se); + } catch (SaslException | TTransportException se) { + throw new IOException("Could not instantiate transport", se); } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PlainSaslHelper.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PlainSaslHelper.java index c06f6ec34653f..5ac29950f4f85 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PlainSaslHelper.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PlainSaslHelper.java @@ -38,6 +38,7 @@ import org.apache.thrift.transport.TSaslClientTransport; 
import org.apache.thrift.transport.TSaslServerTransport; import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; import org.apache.thrift.transport.TTransportFactory; public final class PlainSaslHelper { @@ -64,7 +65,7 @@ public static TTransportFactory getPlainTransportFactory(String authTypeStr) } public static TTransport getPlainTransport(String username, String password, - TTransport underlyingTransport) throws SaslException { + TTransport underlyingTransport) throws SaslException, TTransportException { return new TSaslClientTransport("PLAIN", null, null, null, new HashMap(), new PlainCallbackHandler(username, password), underlyingTransport); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java index 8e7d8e60c176b..3b24ad1ebe14f 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java @@ -46,11 +46,12 @@ public TSetIpAddressProcessor(Iface iface) { } @Override - public boolean process(final TProtocol in, final TProtocol out) throws TException { + public void process(final TProtocol in, final TProtocol out) throws TException { setIpAddress(in); setUserName(in); try { - return super.process(in, out); + super.process(in, out); + return; } finally { THREAD_LOCAL_USER_NAME.remove(); THREAD_LOCAL_IP_ADDRESS.remove(); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java index 4d99496876fdc..c7fa7b5f3e0ac 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java @@ -91,16 +91,10 @@ protected void initializeServer() { // Server args int maxMessageSize = hiveConf.getIntVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_MAX_MESSAGE_SIZE); - int requestTimeout = (int) hiveConf.getTimeVar( - HiveConf.ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT, TimeUnit.SECONDS); - int beBackoffSlotLength = (int) hiveConf.getTimeVar( - HiveConf.ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH, TimeUnit.MILLISECONDS); TThreadPoolServer.Args sargs = new TThreadPoolServer.Args(serverSocket) .processorFactory(processorFactory).transportFactory(transportFactory) .protocolFactory(new TBinaryProtocol.Factory()) .inputProtocolFactory(new TBinaryProtocol.Factory(true, true, maxMessageSize, maxMessageSize)) - .requestTimeout(requestTimeout).requestTimeoutUnit(TimeUnit.SECONDS) - .beBackoffSlotLength(beBackoffSlotLength).beBackoffSlotLengthUnit(TimeUnit.MILLISECONDS) .executorService(executorService); // TCP Server diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java index 752cd54af626b..defe51bc97993 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java @@ -86,6 +86,16 @@ public void setSessionHandle(SessionHandle sessionHandle) { public SessionHandle getSessionHandle() { return 
sessionHandle; } + + @Override + public T unwrap(Class aClass) { + return null; + } + + @Override + public boolean isWrapperFor(Class aClass) { + return false; + } } public ThriftCLIService(CLIService service, String serviceName) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 101d31d609852..30201dcee552d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -74,7 +74,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildStaticConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - "2.0.0 through 2.3.9 and " + + "2.0.0 through 2.3.10 and " + "3.0.0 through 3.1.3.") .version("1.4.0") .stringConf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 564c87a0fca8e..d172af21a9170 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -59,13 +59,12 @@ package object client { "org.pentaho:pentaho-aggdesigner-algorithm")) // Since HIVE-23980, calcite-core included in Hive package jar. - case object v2_3 extends HiveVersion("2.3.9", + case object v2_3 extends HiveVersion("2.3.10", exclusions = Seq("org.apache.calcite:calcite-core", "org.apache.calcite:calcite-druid", "org.apache.calcite.avatica:avatica", - "com.fasterxml.jackson.core:*", "org.apache.curator:*", - "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:aggdesigner-algorithm", "org.apache.hive:hive-vector-code-gen")) // Since Hive 3.0, HookUtils uses org.apache.logging.log4j.util.Strings diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 726341ffdf9e3..95baffdee06cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -211,7 +211,6 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } - // Extract major.minor for testing Spark 3.1.x and 3.0.x with metastore 2.3.9 and Java 11. 
val hiveMetastoreVersion = """^\d+\.\d+""".r.findFirstIn(hiveVersion).get val args = Seq( "--name", "prepare testing tables", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index c7aa412959097..e88a37f019b7d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -149,7 +149,7 @@ class HiveSparkSubmitSuite "--conf", s"${EXECUTOR_MEMORY.key}=512m", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", - "--conf", "spark.sql.hive.metastore.version=2.3.9", + "--conf", "spark.sql.hive.metastore.version=2.3.10", "--conf", "spark.sql.hive.metastore.jars=maven", "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) @@ -370,7 +370,7 @@ class HiveSparkSubmitSuite "--master", "local-cluster[2,1,512]", "--conf", s"${EXECUTOR_MEMORY.key}=512m", "--conf", s"${LEGACY_TIME_PARSER_POLICY.key}=LEGACY", - "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=2.3.9", + "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=2.3.10", "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven", "--conf", s"spark.hadoop.javax.jdo.option.ConnectionURL=$metastore", unusedJar.toString) @@ -387,7 +387,7 @@ object SetMetastoreURLTest extends Logging { val builder = SparkSession.builder() .config(sparkConf) .config(UI_ENABLED.key, "false") - .config(HiveUtils.HIVE_METASTORE_VERSION.key, "2.3.9") + .config(HiveUtils.HIVE_METASTORE_VERSION.key, "2.3.10") // The issue described in SPARK-16901 only appear when // spark.sql.hive.metastore.jars is not set to builtin. .config(HiveUtils.HIVE_METASTORE_JARS.key, "maven") @@ -698,7 +698,7 @@ object SparkSQLConfTest extends Logging { val filteredSettings = super.getAll.filterNot(e => isMetastoreSetting(e._1)) // Always add these two metastore settings at the beginning. - (HiveUtils.HIVE_METASTORE_VERSION.key -> "2.3.9") +: + (HiveUtils.HIVE_METASTORE_VERSION.key -> "2.3.10") +: (HiveUtils.HIVE_METASTORE_JARS.key -> "maven") +: filteredSettings } @@ -726,7 +726,7 @@ object SPARK_9757 extends QueryTest { val hiveWarehouseLocation = Utils.createTempDir() val sparkContext = new SparkContext( new SparkConf() - .set(HiveUtils.HIVE_METASTORE_VERSION.key, "2.3.9") + .set(HiveUtils.HIVE_METASTORE_VERSION.key, "2.3.10") .set(HiveUtils.HIVE_METASTORE_JARS.key, "maven") .set(UI_ENABLED, false) .set(WAREHOUSE_PATH.key, hiveWarehouseLocation.toString)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 397da6c18b50a..5e58959ca4f7d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1627,10 +1627,8 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd test("SPARK-33084: Add jar support Ivy URI in SQL") { val testData = TestHive.getHiveFile("data/files/sample.json").toURI withTable("t") { - // hive-catalog-core has some transitive dependencies which dont exist on maven central - // and hence cannot be found in the test environment or are non-jar (.pom) which cause - // failures in tests. 
Use transitive=false as it should be good enough to test the Ivy - // support in Hive ADD JAR + // Use transitive=false as it should be good enough to test the Ivy support + // in Hive ADD JAR sql(s"ADD JAR ivy://org.apache.hive.hcatalog:hive-hcatalog-core:$hiveVersion" + "?transitive=false") sql( From b371e7dd88009195740f8f5b591447441ea43d0b Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Thu, 9 May 2024 21:47:05 -0700 Subject: [PATCH 41/65] [SPARK-48224][SQL] Disallow map keys from being of variant type ### What changes were proposed in this pull request? This PR disallows map keys from being of variant type. Therefore, SQL statements like `select map(parse_json('{"a": 1}'), 1)`, which would work earlier, will throw an exception now. ### Why are the changes needed? Allowing variant to be the key type of a map can result in undefined behavior as this has not been tested. ### Does this PR introduce _any_ user-facing change? Yes, users could use variants as keys in maps earlier. However, this PR disallows this possibility. ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #46516 from harshmotw-db/map_variant_key. Authored-by: Harsh Motwani Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/util/TypeUtils.scala | 2 +- .../expressions/ComplexTypeSuite.scala | 34 ++++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index d2c708b380cf5..a0d578c66e736 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -58,7 +58,7 @@ object TypeUtils extends QueryErrorsBase { } def checkForMapKeyType(keyType: DataType): TypeCheckResult = { - if (keyType.existsRecursively(_.isInstanceOf[MapType])) { + if (keyType.existsRecursively(dt => dt.isInstanceOf[MapType] || dt.isInstanceOf[VariantType])) { DataTypeMismatch( errorSubClass = "INVALID_MAP_KEY_TYPE", messageParameters = Map( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 5f135e46a3775..497b335289b11 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -359,6 +359,38 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + // map key can't be variant + val map6 = CreateMap(Seq( + Literal.create(new VariantVal(Array[Byte](), Array[Byte]())), + Literal.create(1) + )) + map6.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key") + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert(messageParameters 
=== Map("keyType" -> "\"VARIANT\"")) + } + + // map key can't contain variant + val map7 = CreateMap( + Seq( + CreateStruct( + Seq(Literal.create(1), Literal.create(new VariantVal(Array[Byte](), Array[Byte]()))) + ), + Literal.create(1) + ) + ) + map7.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key") + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert( + messageParameters === Map( + "keyType" -> "\"STRUCT\"" + ) + ) + } + test("MapFromArrays") { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) From 9bb15db85e53b69b9c0ba112cd1dd93d8213eea4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 9 May 2024 22:01:13 -0700 Subject: [PATCH 42/65] [SPARK-48228][PYTHON][CONNECT] Implement the missing function validation in ApplyInXXX ### What changes were proposed in this pull request? Implement the missing function validation in ApplyInXXX https://github.com/apache/spark/pull/46397 fixed this issue for `Cogrouped.ApplyInPandas`, this PR fix remaining methods. ### Why are the changes needed? for better error message: ``` In [12]: df1 = spark.range(11) In [13]: df2 = df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())])) In [14]: df2.show() ``` before this PR, an invalid function causes weird execution errors: ``` 24/05/10 11:37:36 ERROR Executor: Exception in task 0.0 in stage 10.0 (TID 36) org.apache.spark.api.python.PythonException: Traceback (most recent call last): File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1834, in main process() File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1826, in process serializer.dump_stream(out_iter, outfile) File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 531, in dump_stream return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 104, in dump_stream for batch in iterator: File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 524, in init_stream_yield_batches for series in iterator: File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1610, in mapper return f(keys, vals) ^^^^^^^^^^^^^ File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 488, in return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] ^^^^^^^^^^^^^ File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 483, in wrapped result, return_type, _assign_cols_by_name, truncate_return_schema=False ^^^^^^ UnboundLocalError: cannot access local variable 'result' where it is not associated with a value at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:523) at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117) at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:479) at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:601) at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583) at 
org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50) at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:896) ... ``` After this PR, the error happens before execution, which is consistent with Spark Classic, and much clear ``` PySparkValueError: [INVALID_PANDAS_UDF] Invalid function: pandas_udf with function type GROUPED_MAP or the function in groupby.applyInPandas must take either one argument (data) or two arguments (key, data). ``` ### Does this PR introduce _any_ user-facing change? yes, error message changes ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46519 from zhengruifeng/missing_check_in_group. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/connect/group.py | 8 ++++++-- python/pyspark/sql/pandas/functions.py | 4 ++-- .../tests/pandas/test_pandas_grouped_map.py | 20 +++++++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index c916e8acf3e43..2a5bb5939a3f8 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -34,6 +34,7 @@ from pyspark.util import PythonEvalType from pyspark.sql.group import GroupedData as PySparkGroupedData from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps +from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] from pyspark.sql.types import NumericType from pyspark.sql.types import StructType @@ -293,6 +294,7 @@ def applyInPandas( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame + _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf_obj = UserDefinedFunction( func, returnType=schema, @@ -322,6 +324,7 @@ def applyInPandasWithState( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame + _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE) udf_obj = UserDefinedFunction( func, returnType=outputStructType, @@ -360,6 +363,7 @@ def applyInArrow( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame + _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF) udf_obj = UserDefinedFunction( func, returnType=schema, @@ -398,9 +402,8 @@ def applyInPandas( ) -> "DataFrame": from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame - from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] - _validate_pandas_udf(func, schema, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF) + _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF) udf_obj = UserDefinedFunction( func, returnType=schema, @@ -426,6 +429,7 @@ def applyInArrow( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame + _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF) udf_obj = 
UserDefinedFunction( func, returnType=schema, diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 5922a5ced8639..020105bb064ae 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -432,7 +432,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: # validate the pandas udf and return the adjusted eval type -def _validate_pandas_udf(f, returnType, evalType) -> int: +def _validate_pandas_udf(f, evalType) -> int: argspec = getfullargspec(f) # pandas UDF by type hints. @@ -533,7 +533,7 @@ def _validate_pandas_udf(f, returnType, evalType) -> int: def _create_pandas_udf(f, returnType, evalType): - evalType = _validate_pandas_udf(f, returnType, evalType) + evalType = _validate_pandas_udf(f, evalType) if is_remote(): from pyspark.sql.connect.udf import _create_udf as _create_connect_udf diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py index 1e86e12eb74f0..a26d6d02a2bcd 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py @@ -439,6 +439,26 @@ def check_wrong_args(self): pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])) ) + def test_wrong_args_in_apply_func(self): + df1 = self.spark.range(11) + df2 = self.spark.range(22) + + with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"): + df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())])) + + with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"): + df1.groupby("id").applyInArrow(lambda: 1, StructType([StructField("d", DoubleType())])) + + with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"): + df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas( + lambda: 1, StructType([StructField("d", DoubleType())]) + ) + + with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"): + df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow( + lambda: 1, StructType([StructField("d", DoubleType())]) + ) + def test_unsupported_types(self): with self.quiet(): self.check_unsupported_types() From 8ccc8b92be50b1d5ef932873403e62e28c478781 Mon Sep 17 00:00:00 2001 From: Chloe He Date: Thu, 9 May 2024 22:07:04 -0700 Subject: [PATCH 43/65] [SPARK-48201][DOCS][PYTHON] Make some corrections in the docstring of pyspark DataStreamReader methods ### What changes were proposed in this pull request? The docstrings of the pyspark DataStream Reader methods `csv()` and `text()` say that the `path` parameter can be a list, but actually when a list is passed an error is raised. ### Why are the changes needed? Documentation is wrong. ### Does this PR introduce _any_ user-facing change? Yes. Fixes documentation. ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? No Closes #46416 from chloeh13q/fix/streamread-docstring. Authored-by: Chloe He Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/streaming/readwriter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index c2b75dd8f167a..b202a499e8b08 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -553,8 +553,8 @@ def text( Parameters ---------- - path : str or list - string, or list of strings, for input path(s). 
+ path : str + string for input path. Other Parameters ---------------- @@ -641,8 +641,8 @@ def csv( Parameters ---------- - path : str or list - string, or list of strings, for input path(s). + path : str + string for input path. schema : :class:`pyspark.sql.types.StructType` or str, optional an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). From 33cac4436e593c9c501c5ff0eedf923d3a21899c Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Thu, 9 May 2024 22:55:07 -0700 Subject: [PATCH 44/65] [SPARK-47847][CORE] Deprecate `spark.network.remoteReadNioBufferConversion` ### What changes were proposed in this pull request? `spark.network.remoteReadNioBufferConversion` was introduced in https://github.com/apache/spark/commit/2c82745686f4456c4d5c84040a431dcb5b6cb60b, to allow disable [SPARK-24307](https://issues.apache.org/jira/browse/SPARK-24307) for safety, while during the whole Spark 3 period, there are no negative reports, it proves that [SPARK-24307](https://issues.apache.org/jira/browse/SPARK-24307) is solid enough, I propose to mark it deprecated in 3.5.2 and remove in 4.1.0 or later ### Why are the changes needed? Code clean up ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No Closes #46047 from pan3793/SPARK-47847. Authored-by: Cheng Pan Signed-off-by: Dongjoon Hyun --- core/src/main/scala/org/apache/spark/SparkConf.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index fa7911c937378..95955455a9d4b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -645,7 +645,9 @@ private[spark] object SparkConf extends Logging { DeprecatedConfig("spark.blacklist.killBlacklistedExecutors", "3.1.0", "Please use spark.excludeOnFailure.killExcludedExecutors"), DeprecatedConfig("spark.yarn.blacklist.executor.launch.blacklisting.enabled", "3.1.0", - "Please use spark.yarn.executor.launch.excludeOnFailure.enabled") + "Please use spark.yarn.executor.launch.excludeOnFailure.enabled"), + DeprecatedConfig("spark.network.remoteReadNioBufferConversion", "3.5.2", + "Please open a JIRA ticket to report it if you need to use this configuration.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) From 2df494fd4e4e64b9357307fb0c5e8fc1b7491ac3 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 10 May 2024 14:03:08 +0800 Subject: [PATCH 45/65] [SPARK-48158][SQL] Add collation support for XML expressions ### What changes were proposed in this pull request? Introduce collation awareness for XML expressions: from_xml, schema_of_xml, to_xml. ### Why are the changes needed? Add collation support for XML expressions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for XML functions: from_xml, schema_of_xml, to_xml. ### How was this patch tested? E2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46507 from uros-db/xml-expressions. 
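A minimal usage sketch of what this enables, written against the public SQL surface only (the `collate` function and the XML inputs mirror the collation feature and the new tests; this is an illustrative sketch, not code from the patch), e.g. in `spark-shell`:

```scala
// Collated string arguments are now accepted by the XML expressions, and their string
// results follow the session's default string type. Assumes a local build with collations.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("xml-collation-sketch").getOrCreate()

spark.sql("SELECT schema_of_xml(collate('<p><a>1</a></p>', 'UNICODE_CI'))").show(truncate = false)
spark.sql("SELECT from_xml(collate('<p><a>1</a></p>', 'UNICODE_CI'), 'a INT')").show(truncate = false)
spark.sql("SELECT to_xml(named_struct('a', 1, 'b', 2))").show(truncate = false)
```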
Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../catalyst/expressions/xmlExpressions.scala | 9 +- .../sql/CollationSQLExpressionsSuite.scala | 124 ++++++++++++++++++ 2 files changed, 129 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 415d55d19ded2..237d740e04362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -140,7 +141,7 @@ case class XmlToStructs( converter(parser.parse(str)) } - override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil override def sql: String = schema match { case _: MapType => "entries" @@ -178,7 +179,7 @@ case class SchemaOfXml( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false @@ -226,7 +227,7 @@ case class SchemaOfXml( .map(ArrayType(_, containsNull = at.containsNull)) .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) case other: DataType => - xmlInferSchema.canonicalizeType(other).getOrElse(StringType) + xmlInferSchema.canonicalizeType(other).getOrElse(SQLConf.get.defaultStringType) } UTF8String.fromString(dataType.sql) @@ -320,7 +321,7 @@ case class StructsToXml( getAndReset() } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 2b6390151bb9b..dd5703d1284a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.text.SimpleDateFormat + import scala.collection.immutable.Seq import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkRuntimeException} @@ -860,6 +862,128 @@ class CollationSQLExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("Support XmlToStructs xml expression with collation") { + case class XmlToStructsTestCase( + input: String, + collationName: String, + schema: String, + options: String, + result: Row, + structFields: Seq[StructField] + ) + + val testCases = Seq( + XmlToStructsTestCase("
<p><a>1</a></p>
    ", "UTF8_BINARY", "'a INT'", "", + Row(1), Seq( + StructField("a", IntegerType, nullable = true) + )), + XmlToStructsTestCase("
<p><A>true</A><B>0.8</B></p>
    ", "UTF8_BINARY_LCASE", + "'A BOOLEAN, B DOUBLE'", "", Row(true, 0.8), Seq( + StructField("A", BooleanType, nullable = true), + StructField("B", DoubleType, nullable = true) + )), + XmlToStructsTestCase("
<p><s>Spark</s></p>
    ", "UNICODE", "'s STRING'", "", + Row("Spark"), Seq( + StructField("s", StringType("UNICODE"), nullable = true) + )), + XmlToStructsTestCase("
<p><time>26/08/2015</time></p>
    ", "UNICODE_CI", "'time Timestamp'", + ", map('timestampFormat', 'dd/MM/yyyy')", Row( + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0") + ), Seq( + StructField("time", TimestampType, nullable = true) + )) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select from_xml('${t.input}', ${t.schema} ${t.options}) + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StructType(t.structFields) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support SchemaOfXml xml expression with collation") { + case class SchemaOfXmlTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + SchemaOfXmlTestCase("
<p><a>1</a></p>", "UTF8_BINARY", "STRUCT<a: BIGINT>"), + SchemaOfXmlTestCase("<p><A>true</A><B>0.8</B></p>", "UTF8_BINARY_LCASE", + "STRUCT<A: BOOLEAN, B: DOUBLE>"), + SchemaOfXmlTestCase("<p></p>
    ", "UNICODE", "STRUCT<>"), + SchemaOfXmlTestCase("
<p><A>1</A><A>2</A><A>3</A></p>
    ", "UNICODE_CI", + "STRUCT>") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select schema_of_xml('${t.input}') + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support StructsToXml xml expression with collation") { + case class StructsToXmlTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + StructsToXmlTestCase("named_struct('a', 1, 'b', 2)", "UTF8_BINARY", + s""" + | 1 + | 2 + |""".stripMargin), + StructsToXmlTestCase("named_struct('A', true, 'B', 2.0)", "UTF8_BINARY_LCASE", + s""" + | true + | 2.0 + |""".stripMargin), + StructsToXmlTestCase("named_struct()", "UNICODE", + ""), + StructsToXmlTestCase("named_struct('time', to_timestamp('2015-08-26'))", "UNICODE_CI", + s""" + | + |""".stripMargin) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select to_xml(${t.input}) + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + test("Support ParseJson & TryParseJson variant expressions with collation") { case class ParseJsonTestCase( input: String, From d8151186d79459fbde27a01bd97328e73548c55a Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Fri, 10 May 2024 01:09:01 -0700 Subject: [PATCH 46/65] [SPARK-48230][BUILD] Remove unused `jodd-core` ### What changes were proposed in this pull request? Remove a jar that has CVE https://github.com/advisories/GHSA-jrg3-qq99-35g7 ### Why are the changes needed? Previously, `jodd-core` came from Hive transitive deps, while https://github.com/apache/hive/pull/5151 (Hive 2.3.10) cut it out, so we can remove it from Spark now. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46520 from pan3793/SPARK-48230. 
Authored-by: Cheng Pan Signed-off-by: Dongjoon Hyun --- LICENSE-binary | 1 - dev/deps/spark-deps-hadoop-3-hive-2.3 | 1 - licenses-binary/LICENSE-jodd.txt | 24 ------------------------ pom.xml | 6 ------ sql/hive/pom.xml | 4 ---- 5 files changed, 36 deletions(-) delete mode 100644 licenses-binary/LICENSE-jodd.txt diff --git a/LICENSE-binary b/LICENSE-binary index 40271c9924bc4..034215f0ab157 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -436,7 +436,6 @@ com.esotericsoftware:reflectasm org.codehaus.janino:commons-compiler org.codehaus.janino:janino jline:jline -org.jodd:jodd-core com.github.wendykierp:JTransforms pl.edu.icm:JLargeArrays diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 392bacd73277f..29997815e5bc1 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -143,7 +143,6 @@ jline/2.14.6//jline-2.14.6.jar jline/3.24.1//jline-3.24.1.jar jna/5.13.0//jna-5.13.0.jar joda-time/2.12.7//joda-time-2.12.7.jar -jodd-core/3.5.2//jodd-core-3.5.2.jar jpam/1.1//jpam-1.1.jar json/1.8//json-1.8.jar json4s-ast_2.13/4.0.7//json4s-ast_2.13-4.0.7.jar diff --git a/licenses-binary/LICENSE-jodd.txt b/licenses-binary/LICENSE-jodd.txt deleted file mode 100644 index cc6b458adb386..0000000000000 --- a/licenses-binary/LICENSE-jodd.txt +++ /dev/null @@ -1,24 +0,0 @@ -Copyright (c) 2003-present, Jodd Team (https://jodd.org) -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, -this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/pom.xml b/pom.xml index 56a34cedde516..a98efe8aed1e6 100644 --- a/pom.xml +++ b/pom.xml @@ -201,7 +201,6 @@ 3.1.9 3.0.12 2.12.7 - 3.5.2 3.0.0 2.2.11 0.16.0 @@ -2783,11 +2782,6 @@ joda-time ${joda.version}
    - - org.jodd - jodd-core - ${jodd.version} - org.datanucleus datanucleus-core diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 3895d9dc5a634..5e9fc256e7e64 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -152,10 +152,6 @@ joda-time joda-time - - org.jodd - jodd-core - com.google.code.findbugs jsr305 From 256a23883d901c78cf82b4c52e3373322309b8d1 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 10 May 2024 17:12:37 +0900 Subject: [PATCH 47/65] [SPARK-48232][PYTHON][TESTS] Fix 'pyspark.sql.tests.connect.test_connect_session' in Python 3.12 build ### What changes were proposed in this pull request? This PR avoids importing `scipy.sparse` directly which hangs indeterministically specifically with Python 3.12 ### Why are the changes needed? To fix the build with Python 3.12 https://github.com/apache/spark/actions/runs/9022174253/job/24804919747 I was able to reproduce this in my local but a bit indeterministic. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually tested in my local. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46522 from HyukjinKwon/SPARK-48232. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/testing/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index fe25136864eef..8a7aa405e4ac7 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -38,7 +38,7 @@ have_scipy = False have_numpy = False try: - import scipy.sparse # noqa: F401 + import scipy # noqa: F401 have_scipy = True except ImportError: From 259760a5c5e26e33b2ee46282aeb63e4ea701020 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 10 May 2024 18:44:53 +0800 Subject: [PATCH 48/65] [SPARK-48228][PYTHON][CONNECT][FOLLOWUP] Also apply `_validate_pandas_udf` in MapInXXX ### What changes were proposed in this pull request? Also apply `_validate_pandas_udf` in MapInXXX ### Why are the changes needed? to make sure validation in `pandas_udf` is also applied in MapInXXX ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #46524 from zhengruifeng/missing_check_map_in_xxx. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/connect/dataframe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 3c9415adec2dd..ccaaa15f3190c 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -83,6 +83,7 @@ ) from pyspark.sql.connect.functions import builtin as F from pyspark.sql.pandas.types import from_arrow_schema +from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] if TYPE_CHECKING: @@ -1997,6 +1998,7 @@ def _map_partitions( ) -> ParentDataFrame: from pyspark.sql.connect.udf import UserDefinedFunction + _validate_pandas_udf(func, evalType) udf_obj = UserDefinedFunction( func, returnType=schema, From 7ef0440ef22161a6160f7b9000c70b26c84eecf7 Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Fri, 10 May 2024 22:39:15 +0800 Subject: [PATCH 49/65] [SPARK-48146][SQL] Fix aggregate function in With expression child assertion ### What changes were proposed in this pull request? 
In https://github.com/apache/spark/pull/46034, there was a complicated edge case where common expression references in aggregate functions in the child of a `With` expression could become dangling. An assertion was added to avoid that case from happening, but the assertion wasn't fully accurate as a query like: ``` select id between max(if(id between 1 and 2, 2, 1)) over () and id from range(10) ``` would fail the assertion. This PR fixes the assertion to be more accurate. ### Why are the changes needed? This addresses a regression in https://github.com/apache/spark/pull/46034. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46443 from kelvinjian-db/SPARK-48146-agg. Authored-by: Kelvin Jiang Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/expressions/With.scala | 26 +++++++++++++++--- .../RewriteWithExpressionSuite.scala | 27 ++++++++++++++++++- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala index 14deedd9c70fa..29794b33641cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} import org.apache.spark.sql.types.DataType /** @@ -27,9 +28,11 @@ import org.apache.spark.sql.types.DataType */ case class With(child: Expression, defs: Seq[CommonExpressionDef]) extends Expression with Unevaluable { - // We do not allow With to be created with an AggregateExpression in the child, as this would - // create a dangling CommonExpressionRef after rewriting it in RewriteWithExpression. - assert(!child.containsPattern(AGGREGATE_EXPRESSION)) + // We do not allow creating a With expression with an AggregateExpression that contains a + // reference to a common expression defined in that scope (note that it can contain another With + // expression with a common expression ref of the inner With). This is to prevent the creation of + // a dangling CommonExpressionRef after rewriting it in RewriteWithExpression. + assert(!With.childContainsUnsupportedAggExpr(this)) override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION) override def dataType: DataType = child.dataType @@ -92,6 +95,21 @@ object With { val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_)) With(replaced(commonExprRefs), commonExprDefs) } + + private[sql] def childContainsUnsupportedAggExpr(withExpr: With): Boolean = { + lazy val commonExprIds = withExpr.defs.map(_.id).toSet + withExpr.child.exists { + case agg: AggregateExpression => + // Check that the aggregate expression does not contain a reference to a common expression + // in the outer With expression (it is ok if it contains a reference to a common expression + // for a nested With expression). 
+ agg.exists { + case r: CommonExpressionRef => commonExprIds.contains(r.id) + case _ => false + } + case _ => false + } + } } case class CommonExpressionId(id: Long = CommonExpressionId.newId, canonicalized: Boolean = false) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index d482b18d93316..8f023fa4156bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -353,7 +353,7 @@ class RewriteWithExpressionSuite extends PlanTest { ) } - test("aggregate functions in child of WITH expression is not supported") { + test("aggregate functions in child of WITH expression with ref is not supported") { val a = testRelation.output.head intercept[java.lang.AssertionError] { val expr = With(a - 1) { case Seq(ref) => @@ -366,4 +366,29 @@ class RewriteWithExpressionSuite extends PlanTest { Optimizer.execute(plan) } } + + test("WITH expression nested in aggregate function") { + val a = testRelation.output.head + val expr = With(a + 1) { case Seq(ref) => + ref * ref + } + val nestedExpr = With(a - 1) { case Seq(ref) => + ref * max(expr) + ref + } + val plan = testRelation.groupBy(a)(nestedExpr.as("col")).analyze + val commonExpr1Id = expr.defs.head.id.id + val commonExpr1Name = s"_common_expr_$commonExpr1Id" + val commonExpr2Id = nestedExpr.defs.head.id.id + val commonExpr2Name = s"_common_expr_$commonExpr2Id" + val aggExprName = "_aggregateexpression" + comparePlans( + Optimizer.execute(plan), + testRelation + .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .groupBy(a)(a, max($"$commonExpr1Name" * $"$commonExpr1Name").as(aggExprName)) + .select($"a", $"$aggExprName", (a - 1).as(commonExpr2Name)) + .select(($"$commonExpr2Name" * $"$aggExprName" + $"$commonExpr2Name").as("col")) + .analyze + ) + } } From 73bb619d45b2d0699ca4a9d251eea57c359f275b Mon Sep 17 00:00:00 2001 From: fred-db Date: Fri, 10 May 2024 07:45:28 -0700 Subject: [PATCH 50/65] [SPARK-48235][SQL] Directly pass join instead of all arguments to getBroadcastBuildSide and getShuffleHashJoinBuildSide ### What changes were proposed in this pull request? * Refactor getBroadcastBuildSide and getShuffleHashJoinBuildSide to pass the join as argument instead of all member variables of the join separately. ### Why are the changes needed? * Makes to code easier to read. ### Does this PR introduce _any_ user-facing change? * no ### How was this patch tested? * Existing UTs ### Was this patch authored or co-authored using generative AI tooling? * No Closes #46525 from fred-db/parameter-change. 
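As a quick reference, a sketch of a call site after this refactor (the plan arguments are placeholders; the helper signatures are the ones in the diff below):

```scala
// Callers now hand over the logical Join node itself instead of its
// left/right/joinType/hint fields separately.
import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf

object JoinSelectionSketch extends JoinSelectionHelper {
  def pickBuildSides(left: LogicalPlan, right: LogicalPlan): Unit = {
    val join = Join(left, right, Inner, None, JoinHint(None, None))
    val bhjSide = getBroadcastBuildSide(join, hintOnly = false, SQLConf.get)
    val shjSide = getShuffleHashJoinBuildSide(join, hintOnly = false, SQLConf.get)
    println(s"BHJ build side: $bhjSide, SHJ build side: $shjSide")
  }
}
```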
Authored-by: fred-db Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/optimizer/joins.scala | 56 ++++++++---------- .../optimizer/JoinSelectionHelperSuite.scala | 59 +++++-------------- .../spark/sql/execution/SparkStrategies.scala | 6 +- 3 files changed, 40 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 2b4ee033b0885..5571178832db7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -289,58 +289,52 @@ case object BuildLeft extends BuildSide trait JoinSelectionHelper { def getBroadcastBuildSide( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - hint: JoinHint, + join: Join, hintOnly: Boolean, conf: SQLConf): Option[BuildSide] = { val buildLeft = if (hintOnly) { - hintToBroadcastLeft(hint) + hintToBroadcastLeft(join.hint) } else { - canBroadcastBySize(left, conf) && !hintToNotBroadcastLeft(hint) + canBroadcastBySize(join.left, conf) && !hintToNotBroadcastLeft(join.hint) } val buildRight = if (hintOnly) { - hintToBroadcastRight(hint) + hintToBroadcastRight(join.hint) } else { - canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint) + canBroadcastBySize(join.right, conf) && !hintToNotBroadcastRight(join.hint) } getBuildSide( - canBuildBroadcastLeft(joinType) && buildLeft, - canBuildBroadcastRight(joinType) && buildRight, - left, - right + canBuildBroadcastLeft(join.joinType) && buildLeft, + canBuildBroadcastRight(join.joinType) && buildRight, + join.left, + join.right ) } def getShuffleHashJoinBuildSide( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - hint: JoinHint, + join: Join, hintOnly: Boolean, conf: SQLConf): Option[BuildSide] = { val buildLeft = if (hintOnly) { - hintToShuffleHashJoinLeft(hint) + hintToShuffleHashJoinLeft(join.hint) } else { - hintToPreferShuffleHashJoinLeft(hint) || - (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(left, conf) && - muchSmaller(left, right, conf)) || + hintToPreferShuffleHashJoinLeft(join.hint) || + (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.left, conf) && + muchSmaller(join.left, join.right, conf)) || forceApplyShuffledHashJoin(conf) } val buildRight = if (hintOnly) { - hintToShuffleHashJoinRight(hint) + hintToShuffleHashJoinRight(join.hint) } else { - hintToPreferShuffleHashJoinRight(hint) || - (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(right, conf) && - muchSmaller(right, left, conf)) || + hintToPreferShuffleHashJoinRight(join.hint) || + (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.right, conf) && + muchSmaller(join.right, join.left, conf)) || forceApplyShuffledHashJoin(conf) } getBuildSide( - canBuildShuffledHashJoinLeft(joinType) && buildLeft, - canBuildShuffledHashJoinRight(joinType) && buildRight, - left, - right + canBuildShuffledHashJoinLeft(join.joinType) && buildLeft, + canBuildShuffledHashJoinRight(join.joinType) && buildRight, + join.left, + join.right ) } @@ -401,10 +395,8 @@ trait JoinSelectionHelper { } def canPlanAsBroadcastHashJoin(join: Join, conf: SQLConf): Boolean = { - getBroadcastBuildSide(join.left, join.right, join.joinType, - join.hint, hintOnly = true, conf).isDefined || - getBroadcastBuildSide(join.left, join.right, join.joinType, - join.hint, hintOnly = false, conf).isDefined + getBroadcastBuildSide(join, hintOnly = true, 
conf).isDefined || + getBroadcastBuildSide(join, hintOnly = false, conf).isDefined } def canPruneLeft(joinType: JoinType): Boolean = joinType match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala index 6acce44922f69..61fb68cfba863 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.AttributeMap import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH} +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH} import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import org.apache.spark.sql.internal.SQLConf @@ -38,16 +38,15 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { size = Some(1000), attributeStats = AttributeMap(Seq())) + private val join = Join(left, right, Inner, None, JoinHint(None, None)) + private val hintBroadcast = Some(HintInfo(Some(BROADCAST))) private val hintNotToBroadcast = Some(HintInfo(Some(NO_BROADCAST_HASH))) private val hintShuffleHash = Some(HintInfo(Some(SHUFFLE_HASH))) test("getBroadcastBuildSide (hintOnly = true) return BuildLeft with only a left hint") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(hintBroadcast, None), + join.copy(hint = JoinHint(hintBroadcast, None)), hintOnly = true, SQLConf.get ) @@ -56,10 +55,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = true) return BuildRight with only a right hint") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, hintBroadcast), + join.copy(hint = JoinHint(None, hintBroadcast)), hintOnly = true, SQLConf.get ) @@ -68,10 +64,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = true) return smaller side with both having hints") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(hintBroadcast, hintBroadcast), + join.copy(hint = JoinHint(hintBroadcast, hintBroadcast)), hintOnly = true, SQLConf.get ) @@ -80,10 +73,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = true) return None when no side has a hint") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, None), + join.copy(hint = JoinHint(None, None)), hintOnly = true, SQLConf.get ) @@ -92,10 +82,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = false) return BuildRight when right is broadcastable") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, None), + join.copy(hint = JoinHint(None, None)), hintOnly = false, SQLConf.get ) @@ -105,10 +92,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = false) return None when right has no broadcast 
hint") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, hintNotToBroadcast ), + join.copy(hint = JoinHint(None, hintNotToBroadcast)), hintOnly = false, SQLConf.get ) @@ -118,10 +102,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildLeft with only a left hint") { val broadcastSide = getShuffleHashJoinBuildSide( - left, - right, - Inner, - JoinHint(hintShuffleHash, None), + join.copy(hint = JoinHint(hintShuffleHash, None)), hintOnly = true, SQLConf.get ) @@ -130,10 +111,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildRight with only a right hint") { val broadcastSide = getShuffleHashJoinBuildSide( - left, - right, - Inner, - JoinHint(None, hintShuffleHash), + join.copy(hint = JoinHint(None, hintShuffleHash)), hintOnly = true, SQLConf.get ) @@ -142,10 +120,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = true) return smaller side when both have hints") { val broadcastSide = getShuffleHashJoinBuildSide( - left, - right, - Inner, - JoinHint(hintShuffleHash, hintShuffleHash), + join.copy(hint = JoinHint(hintShuffleHash, hintShuffleHash)), hintOnly = true, SQLConf.get ) @@ -154,10 +129,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = true) return None when no side has a hint") { val broadcastSide = getShuffleHashJoinBuildSide( - left, - right, - Inner, - JoinHint(None, None), + join.copy(hint = JoinHint(None, None)), hintOnly = true, SQLConf.get ) @@ -166,10 +138,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = false) return BuildRight when right is smaller") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, None), + join.copy(hint = JoinHint(None, None)), hintOnly = false, SQLConf.get ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 348cc00a1f976..9e14d13b5cb1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -248,8 +248,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val hashJoinSupport = hashJoinSupported(leftKeys, rightKeys) def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = { if (hashJoinSupport) { - val buildSide = getBroadcastBuildSide( - left, right, joinType, hint, onlyLookingAtHint, conf) + val buildSide = getBroadcastBuildSide(j, onlyLookingAtHint, conf) checkHintBuildSide(onlyLookingAtHint, buildSide, joinType, hint, true) buildSide.map { buildSide => @@ -269,8 +268,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def createShuffleHashJoin(onlyLookingAtHint: Boolean) = { if (hashJoinSupport) { - val buildSide = getShuffleHashJoinBuildSide( - left, right, joinType, hint, onlyLookingAtHint, conf) + val buildSide = getShuffleHashJoinBuildSide(j, onlyLookingAtHint, conf) checkHintBuildSide(onlyLookingAtHint, buildSide, joinType, hint, false) buildSide.map { buildSide => From c5b6ec734bd0c47551b59f9de13c6323b80974b2 Mon Sep 17 00:00:00 
2001 From: Yuming Wang Date: Fri, 10 May 2024 08:22:03 -0700 Subject: [PATCH 51/65] [SPARK-47441][YARN] Do not add log link for unmanaged AM in Spark UI ### What changes were proposed in this pull request? This PR makes it do not add log link for unmanaged AM in Spark UI. ### Why are the changes needed? Avoid start driver error messages: ``` 24/03/18 04:58:25,022 ERROR [spark-listener-group-appStatus] scheduler.AsyncEventQueue:97 : Listener AppStatusListener threw an exception java.lang.NumberFormatException: For input string: "null" at java.lang.NumberFormatException.forInputString(NumberFormatException.java:67) ~[?:?] at java.lang.Integer.parseInt(Integer.java:668) ~[?:?] at java.lang.Integer.parseInt(Integer.java:786) ~[?:?] at scala.collection.immutable.StringLike.toInt(StringLike.scala:310) ~[scala-library-2.12.18.jar:?] at scala.collection.immutable.StringLike.toInt$(StringLike.scala:310) ~[scala-library-2.12.18.jar:?] at scala.collection.immutable.StringOps.toInt(StringOps.scala:33) ~[scala-library-2.12.18.jar:?] at org.apache.spark.util.Utils$.parseHostPort(Utils.scala:1105) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.status.ProcessSummaryWrapper.(storeTypes.scala:609) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.status.LiveMiscellaneousProcess.doUpdate(LiveEntity.scala:1045) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.status.LiveEntity.write(LiveEntity.scala:50) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.status.AppStatusListener.update(AppStatusListener.scala:1233) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.status.AppStatusListener.onMiscellaneousProcessAdded(AppStatusListener.scala:1445) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.status.AppStatusListener.onOtherEvent(AppStatusListener.scala:113) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.scheduler.SparkListenerBus.doPostEvent(SparkListenerBus.scala:100) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.scheduler.SparkListenerBus.doPostEvent$(SparkListenerBus.scala:28) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.scheduler.AsyncEventQueue.doPostEvent(AsyncEventQueue.scala:37) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.scheduler.AsyncEventQueue.doPostEvent(AsyncEventQueue.scala:37) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.util.ListenerBus.postToAll(ListenerBus.scala:117) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.util.ListenerBus.postToAll$(ListenerBus.scala:101) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.scheduler.AsyncEventQueue.super$postToAll(AsyncEventQueue.scala:105) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.scheduler.AsyncEventQueue.$anonfun$dispatch$1(AsyncEventQueue.scala:105) ~[spark-core_2.12-3.5.1.jar:3.5.1] at scala.runtime.java8.JFunction0$mcJ$sp.apply(JFunction0$mcJ$sp.java:23) ~[scala-library-2.12.18.jar:?] at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62) ~[scala-library-2.12.18.jar:?] 
at org.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:100) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.scheduler.AsyncEventQueue$$anon$2.$anonfun$run$1(AsyncEventQueue.scala:96) ~[spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1356) [spark-core_2.12-3.5.1.jar:3.5.1] at org.apache.spark.scheduler.AsyncEventQueue$$anon$2.run(AsyncEventQueue.scala:96) [spark-core_2.12-3.5.1.jar:3.5.1] ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual testing: ```shell bin/spark-sql --master yarn --conf spark.yarn.unmanagedAM.enabled=true ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45565 from wangyum/SPARK-47441. Authored-by: Yuming Wang Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index dffb05e196d78..8f20f6602ec5c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -793,9 +793,9 @@ private[spark] class ApplicationMaster( override def onStart(): Unit = { driver.send(RegisterClusterManager(self)) - // if deployment mode for yarn Application is client + // if deployment mode for yarn Application is managed client // then send the AM Log Info to spark driver - if (!isClusterMode) { + if (!isClusterMode && !sparkConf.get(YARN_UNMANAGED_AM)) { val hostPort = YarnContainerInfoHelper.getNodeManagerHttpAddress(None) val yarnAMID = "yarn-am" val info = new MiscellaneousProcessDetails(hostPort, From 5beaf85cd5ef2b84a67ebce712e8d73d1e7d41ff Mon Sep 17 00:00:00 2001 From: Chaoqin Li Date: Fri, 10 May 2024 08:24:42 -0700 Subject: [PATCH 52/65] [SPARK-47793][TEST][FOLLOWUP] Fix flaky test for Python data source exactly once ### What changes were proposed in this pull request? Fix the flakiness in python streaming source exactly once test. The last executed batch may not be recorded in query progress, which cause the expected rows doesn't match. This fix takes the uncompleted batch into account and relax the condition ### Why are the changes needed? Fix flaky test. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Test change. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46481 from chaoqin-li1123/fix_python_ds_test. 
Authored-by: Chaoqin Li Signed-off-by: Dongjoon Hyun --- .../python/PythonStreamingDataSourceSuite.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala index 97e6467c3eaf5..d1f7c597b308f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala @@ -299,7 +299,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { val checkpointDir = new File(path, "checkpoint") val outputDir = new File(path, "output") val df = spark.readStream.format(dataSourceName).load() - var lastBatch = 0 + var lastBatchId = 0 // Restart streaming query multiple times to verify exactly once guarantee. for (i <- 1 to 5) { @@ -323,11 +323,15 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { } q.stop() q.awaitTermination() - lastBatch = q.lastProgress.batchId.toInt + lastBatchId = q.lastProgress.batchId.toInt } - assert(lastBatch > 20) + assert(lastBatchId > 20) + val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count() + // There may be one uncommitted batch that is not recorded in query progress. + // The number of batch can be lastBatchId + 1 or lastBatchId + 2. + assert(rowCount == 2 * (lastBatchId + 1) || rowCount == 2 * (lastBatchId + 2)) checkAnswer(spark.read.format("json").load(outputDir.getAbsolutePath), - (0 to 2 * lastBatch + 1).map(Row(_))) + (0 until rowCount.toInt).map(Row(_))) } } From a6632ffa16f6907eba96e745920d571924bf4b63 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Sat, 11 May 2024 00:37:54 +0800 Subject: [PATCH 53/65] [SPARK-48143][SQL] Use lightweight exceptions for control-flow between UnivocityParser and FailureSafeParser # What changes were proposed in this pull request? New lightweight exception for control-flow between UnivocityParser and FalureSafeParser to speed-up malformed CSV parsing. This is a different way to implement these reverted changes: https://github.com/apache/spark/pull/46478 The previous implementation was more invasive - removing `cause` from `BadRecordException` could break upper code, which unwraps errors and checks the types of the causes. This implementation only touches `FailureSafeParser` and `UnivocityParser` since in the codebase they are always used together, unlike `JacksonParser` and `StaxXmlParser`. Removing stacktrace from `BadRecordException` is safe, since the cause itself has an adequate stacktrace (except pure control-flow cases). ### Why are the changes needed? Parsing in `PermissiveMode` is slow due to heavy exception construction (stacktrace filling + string template substitution in `SparkRuntimeException`) ### Does this PR introduce _any_ user-facing change? No, since `FailureSafeParser` unwraps `BadRecordException` and correctly rethrows user-facing exceptions in `FailFastMode` ### How was this patch tested? - `testOnly org.apache.spark.sql.catalyst.csv.UnivocityParserSuite` - Manually run csv benchmark - Manually checked correct and malformed csv in sherk-shell (org.apache.spark.SparkException is thrown with the stacktrace) ### Was this patch authored or co-authored using generative AI tooling? 
No Closes #46500 from vladimirg-db/vladimirg-db/use-special-lighweight-exception-for-control-flow-between-univocity-parser-and-failure-safe-parser. Authored-by: Vladimir Golubev Signed-off-by: Wenchen Fan --- .../sql/catalyst/csv/UnivocityParser.scala | 5 +++-- .../catalyst/util/BadRecordException.scala | 22 ++++++++++++++++--- .../sql/catalyst/util/FailureSafeParser.scala | 11 ++++++++-- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index a5158d8a22c6b..4d95097e16816 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -316,7 +316,7 @@ class UnivocityParser( throw BadRecordException( () => getCurrentInput, () => Array.empty, - QueryExecutionErrors.malformedCSVRecordError("")) + LazyBadRecordCauseWrapper(() => QueryExecutionErrors.malformedCSVRecordError(""))) } val currentInput = getCurrentInput @@ -326,7 +326,8 @@ class UnivocityParser( // However, we still have chance to parse some of the tokens. It continues to parses the // tokens normally and sets null when `ArrayIndexOutOfBoundsException` occurs for missing // tokens. - Some(QueryExecutionErrors.malformedCSVRecordError(currentInput.toString)) + Some(LazyBadRecordCauseWrapper( + () => QueryExecutionErrors.malformedCSVRecordError(currentInput.toString))) } else None // When the length of the returned tokens is identical to the length of the parsed schema, // we just need to: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala index 65a56c1064e45..654b0b8c73e51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -67,16 +67,32 @@ case class PartialResultArrayException( extends Exception(cause) /** - * Exception thrown when the underlying parser meet a bad record and can't parse it. + * Exception thrown when the underlying parser met a bad record and can't parse it. + * The stacktrace is not collected for better preformance, and thus, this exception should + * not be used in a user-facing context. * @param record a function to return the record that cause the parser to fail * @param partialResults a function that returns an row array, which is the partial results of * parsing this bad record. - * @param cause the actual exception about why the record is bad and can't be parsed. + * @param cause the actual exception about why the record is bad and can't be parsed. It's better + * to use `LazyBadRecordCauseWrapper` here to delay heavy cause construction + * until it's needed. */ case class BadRecordException( @transient record: () => UTF8String, @transient partialResults: () => Array[InternalRow] = () => Array.empty[InternalRow], - cause: Throwable) extends Exception(cause) + cause: Throwable) extends Exception(cause) { + override def getStackTrace(): Array[StackTraceElement] = new Array[StackTraceElement](0) + override def fillInStackTrace(): Throwable = this +} + +/** + * Exception to use as `BadRecordException` cause to delay heavy user-facing exception construction. 
+ * Does not contain stacktrace and used only for control flow + */ +case class LazyBadRecordCauseWrapper(cause: () => Throwable) extends Exception() { + override def getStackTrace(): Array[StackTraceElement] = new Array[StackTraceElement](0) + override def fillInStackTrace(): Throwable = this +} /** * Exception thrown when the underlying parser parses a JSON array as a struct. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index 10cd159c769b2..d9946d1b12ec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -78,10 +78,17 @@ class FailureSafeParser[IN]( case StringAsDataTypeException(fieldName, fieldValue, dataType) => throw QueryExecutionErrors.cannotParseStringAsDataTypeError(e.record().toString, fieldName, fieldValue, dataType) - case other => throw QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError( - toResultRow(e.partialResults().headOption, e.record).toString, other) + case causeWrapper: LazyBadRecordCauseWrapper => + throwMalformedRecordsDetectedInRecordParsingError(e, causeWrapper.cause()) + case cause => throwMalformedRecordsDetectedInRecordParsingError(e, cause) } } } } + + private def throwMalformedRecordsDetectedInRecordParsingError( + e: BadRecordException, cause: Throwable): Nothing = { + throw QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError( + toResultRow(e.partialResults().headOption, e.record).toString, cause) + } } From 2225aa1dab0fdb358ce032e07057a54aaf4e456f Mon Sep 17 00:00:00 2001 From: fred-db Date: Fri, 10 May 2024 11:02:19 -0700 Subject: [PATCH 54/65] [SPARK-48144][SQL] Fix `canPlanAsBroadcastHashJoin` to respect shuffle join hints ### What changes were proposed in this pull request? * Currently, `canPlanAsBroadcastHashJoin` incorrectly returns that a join can be planned as a BHJ, even though the join contains a SHJ. * To fix this, add some logic that checks whether the join contains a SHJ hint before checking if the join can be broadcasted. * Also made a small refactor to the `JoinSelectionHelperSuite` to make it a bit more readable. ### Why are the changes needed? * `canPlanAsBroadcastHashJoin` should be in sync with the join selection in `SparkStrategies`. Currently, it is not in sync. ### Does this PR introduce _any_ user-facing change? Yes, semi / anti joins that could not have been planned as broadcasts would now not be pushed through aggregates anymore. Generally, this would be a performance improvement. ### How was this patch tested? * Added UTs to check that a join with a SHJ hint is not marked as being planned as a BHJ. * Added tests to keep `canPlanAsBroadcastHashJoin` and the `JoinSelection` codepath in sync. ### Was this patch authored or co-authored using generative AI tooling? * No Closes #46401 from fred-db/fix-hint. 
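For illustration, a minimal sketch (not part of this patch) of the scenario the fix targets: the build side is small enough to broadcast, but an explicit `SHUFFLE_HASH` hint is present, so the join should no longer be reported as plannable as a broadcast hash join. Table names and the object name below are made up.

```scala
import org.apache.spark.sql.SparkSession

object ShuffleHashHintSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("shj-hint-sketch").getOrCreate()
    import spark.implicits._

    // A tiny table that would normally be chosen as the broadcast side.
    Seq((1, "a"), (2, "b")).toDF("key", "v").createOrReplaceTempView("small")
    spark.range(1000).selectExpr("id AS key").createOrReplaceTempView("big")

    // Without a hint, the planner typically picks a broadcast hash join here.
    spark.sql("SELECT * FROM big JOIN small ON big.key = small.key").explain()

    // With an explicit SHUFFLE_HASH hint, a shuffled hash join is requested instead,
    // and the fixed canPlanAsBroadcastHashJoin agrees with that decision.
    spark.sql(
      "SELECT /*+ SHUFFLE_HASH(small) */ * FROM big JOIN small ON big.key = small.key"
    ).explain()

    spark.stop()
  }
}
```

Under these assumptions, `explain()` on the hinted query should show a shuffled hash join rather than a broadcast hash join, matching the updated `canPlanAsBroadcastHashJoin` result exercised by the new test below.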
Authored-by: fred-db Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/optimizer/joins.scala | 38 +++++++++++++++---- .../spark/sql/execution/SparkStrategies.scala | 17 --------- .../org/apache/spark/sql/JoinSuite.scala | 26 ++++++++++++- 3 files changed, 55 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 5571178832db7..9fc4873c248b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.util.control.NonFatal -import org.apache.spark.internal.LogKeys.JOIN_CONDITION -import org.apache.spark.internal.MDC +import org.apache.spark.internal.{Logging, MDC} +import org.apache.spark.internal.LogKeys.{HASH_JOIN_KEYS, JOIN_CONDITION} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractFiltersAndInnerJoins} +import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractFiltersAndInnerJoins, ExtractSingleColumnNullAwareAntiJoin} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -286,7 +287,7 @@ case object BuildRight extends BuildSide case object BuildLeft extends BuildSide -trait JoinSelectionHelper { +trait JoinSelectionHelper extends Logging { def getBroadcastBuildSide( join: Join, @@ -394,9 +395,32 @@ trait JoinSelectionHelper { } } - def canPlanAsBroadcastHashJoin(join: Join, conf: SQLConf): Boolean = { - getBroadcastBuildSide(join, hintOnly = true, conf).isDefined || - getBroadcastBuildSide(join, hintOnly = false, conf).isDefined + protected def hashJoinSupported + (leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Boolean = { + val result = leftKeys.concat(rightKeys).forall(e => UnsafeRowUtils.isBinaryStable(e.dataType)) + if (!result) { + val keysNotSupportingHashJoin = leftKeys.concat(rightKeys).filterNot( + e => UnsafeRowUtils.isBinaryStable(e.dataType)) + logWarning(log"Hash based joins are not supported due to joining on keys that don't " + + log"support binary equality. 
Keys not supporting hash joins: " + + log"${ + MDC(HASH_JOIN_KEYS, keysNotSupportingHashJoin.map( + e => e.toString + " due to DataType: " + e.dataType.typeName).mkString(", ")) + }") + } + result + } + + def canPlanAsBroadcastHashJoin(join: Join, conf: SQLConf): Boolean = join match { + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _, _) => + val hashJoinSupport = hashJoinSupported(leftKeys, rightKeys) + val noShufflePlannedBefore = + !hashJoinSupport || getShuffleHashJoinBuildSide(join, hintOnly = true, conf).isEmpty + getBroadcastBuildSide(join, hintOnly = true, conf).isDefined || + (noShufflePlannedBefore && + getBroadcastBuildSide(join, hintOnly = false, conf).isDefined) + case ExtractSingleColumnNullAwareAntiJoin(_, _) => true + case _ => false } def canPruneLeft(joinType: JoinType): Boolean = joinType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9e14d13b5cb1c..f0682e6b9afc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution import java.util.Locale import org.apache.spark.{SparkException, SparkUnsupportedOperationException} -import org.apache.spark.internal.LogKeys.HASH_JOIN_KEYS -import org.apache.spark.internal.MDC import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow @@ -33,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -208,20 +205,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - private def hashJoinSupported - (leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Boolean = { - val result = leftKeys.concat(rightKeys).forall(e => UnsafeRowUtils.isBinaryStable(e.dataType)) - if (!result) { - val keysNotSupportingHashJoin = leftKeys.concat(rightKeys).filterNot( - e => UnsafeRowUtils.isBinaryStable(e.dataType)) - logWarning(log"Hash based joins are not supported due to joining on keys that don't " + - log"support binary equality. Keys not supporting hash joins: " + - log"${MDC(HASH_JOIN_KEYS, keysNotSupportingHashJoin.map( - e => e.toString + " due to DataType: " + e.dataType.typeName).mkString(", "))}") - } - result - } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // If it is an equi-join, we first look at the join hints w.r.t. 
the following order: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index be6862f5b96b7..fcb937d82ba42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_T import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder} -import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper} import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION} import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -40,7 +40,8 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.tags.SlowSQLTest @SlowSQLTest -class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper { +class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper + with JoinSelectionHelper { import testImplicits._ setupTestData() @@ -61,6 +62,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val sqlString = pair._1 val c = pair._2 val df = sql(sqlString) + val optimized = df.queryExecution.optimizedPlan val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: BroadcastHashJoinExec => j @@ -74,6 +76,10 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan if (operators.head.getClass != c) { fail(s"$sqlString expected operator: $c, but got ${operators.head}\n physical: \n$physical") } + assert( + canPlanAsBroadcastHashJoin(optimized.asInstanceOf[Join], conf) === + operators.head.isInstanceOf[BroadcastHashJoinExec], + "canPlanAsBroadcastHashJoin not in sync with join selection codepath!") operators.head } @@ -89,11 +95,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size == 1) assert(planned.head.isInstanceOf[CartesianProductExec]) + assert(!canPlanAsBroadcastHashJoin(join, conf)) val plannedWithHint = spark.sessionState.planner.JoinSelection(joinWithHint) assert(plannedWithHint.size == 1) assert(plannedWithHint.head.isInstanceOf[BroadcastNestedLoopJoinExec]) assert(plannedWithHint.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildLeft) + assert(!canPlanAsBroadcastHashJoin(joinWithHint, conf)) } } @@ -112,10 +120,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size == 1) assert(planned.head.isInstanceOf[BroadcastHashJoinExec]) + assert(canPlanAsBroadcastHashJoin(join, conf)) val plannedWithHint = spark.sessionState.planner.JoinSelection(joinWithHint) assert(plannedWithHint.size == 1) assert(plannedWithHint.head.isInstanceOf[SortMergeJoinExec]) + assert(!canPlanAsBroadcastHashJoin(joinWithHint, conf)) } test("NO_BROADCAST_AND_REPLICATION controls build side in BNLJ") { @@ -131,11 +141,13 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan assert(planned.size == 1) assert(planned.head.isInstanceOf[BroadcastNestedLoopJoinExec]) assert(planned.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildRight) + assert(!canPlanAsBroadcastHashJoin(join, conf)) val plannedWithHint = spark.sessionState.planner.JoinSelection(joinWithHint) assert(plannedWithHint.size == 1) assert(plannedWithHint.head.isInstanceOf[BroadcastNestedLoopJoinExec]) assert(plannedWithHint.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildLeft) + assert(!canPlanAsBroadcastHashJoin(joinWithHint, conf)) } test("join operator selection") { @@ -191,6 +203,16 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan // ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } // } + test("broadcastable join with shuffle join hint") { + spark.sharedState.cacheManager.clearCache() + sql("CACHE TABLE testData") + // Make sure it's planned as broadcast join without the hint. + assertJoin("SELECT * FROM testData JOIN testData2 ON key = a", + classOf[BroadcastHashJoinExec]) + assertJoin("SELECT /*+ SHUFFLE_HASH(testData) */ * FROM testData JOIN testData2 ON key = a", + classOf[ShuffledHashJoinExec]) + } + test("broadcasted hash join operator selection") { spark.sharedState.cacheManager.clearCache() sql("CACHE TABLE testData") From 726ef8aa66ea6e56b739f3b16f99e457a0febb81 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 10 May 2024 15:34:12 -0700 Subject: [PATCH 55/65] Revert "[SPARK-48230][BUILD] Remove unused `jodd-core`" This reverts commit d8151186d79459fbde27a01bd97328e73548c55a. --- LICENSE-binary | 1 + dev/deps/spark-deps-hadoop-3-hive-2.3 | 1 + licenses-binary/LICENSE-jodd.txt | 24 ++++++++++++++++++++++++ pom.xml | 6 ++++++ sql/hive/pom.xml | 4 ++++ 5 files changed, 36 insertions(+) create mode 100644 licenses-binary/LICENSE-jodd.txt diff --git a/LICENSE-binary b/LICENSE-binary index 034215f0ab157..40271c9924bc4 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -436,6 +436,7 @@ com.esotericsoftware:reflectasm org.codehaus.janino:commons-compiler org.codehaus.janino:janino jline:jline +org.jodd:jodd-core com.github.wendykierp:JTransforms pl.edu.icm:JLargeArrays diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 29997815e5bc1..392bacd73277f 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -143,6 +143,7 @@ jline/2.14.6//jline-2.14.6.jar jline/3.24.1//jline-3.24.1.jar jna/5.13.0//jna-5.13.0.jar joda-time/2.12.7//joda-time-2.12.7.jar +jodd-core/3.5.2//jodd-core-3.5.2.jar jpam/1.1//jpam-1.1.jar json/1.8//json-1.8.jar json4s-ast_2.13/4.0.7//json4s-ast_2.13-4.0.7.jar diff --git a/licenses-binary/LICENSE-jodd.txt b/licenses-binary/LICENSE-jodd.txt new file mode 100644 index 0000000000000..cc6b458adb386 --- /dev/null +++ b/licenses-binary/LICENSE-jodd.txt @@ -0,0 +1,24 @@ +Copyright (c) 2003-present, Jodd Team (https://jodd.org) +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/pom.xml b/pom.xml index a98efe8aed1e6..56a34cedde516 100644 --- a/pom.xml +++ b/pom.xml @@ -201,6 +201,7 @@ 3.1.9 3.0.12 2.12.7 + 3.5.2 3.0.0 2.2.11 0.16.0 @@ -2782,6 +2783,11 @@ joda-time ${joda.version} + + org.jodd + jodd-core + ${jodd.version} + org.datanucleus datanucleus-core diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 5e9fc256e7e64..3895d9dc5a634 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -152,6 +152,10 @@ joda-time joda-time + + org.jodd + jodd-core + com.google.code.findbugs jsr305 From 5b3b8a90638c49fc7ddcace69a85989c1053f1ab Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 10 May 2024 15:48:08 -0700 Subject: [PATCH 56/65] [SPARK-48236][BUILD] Add `commons-lang:commons-lang:2.6` back to support legacy Hive UDF jars ### What changes were proposed in this pull request? This PR aims to add `commons-lang:commons-lang:2.6` back to support legacy Hive UDF jars . This is a partial revert of SPARK-47018 . ### Why are the changes needed? Recently, we dropped `commons-lang:commons-lang` during Hive upgrade. - #46468 However, only Apache Hive 2.3.10 or 4.0.0 dropped it. In other words, Hive 2.0.0 ~ 2.3.9 and Hive 3.0.0 ~ 3.1.3 requires it. As a result, all existing UDF jars built against those versions requires `commons-lang:commons-lang` still. - https://github.com/apache/hive/pull/4892 For example, Apache Hive 3.1.3 code: - https://github.com/apache/hive/blob/af7059e2bdc8b18af42e0b7f7163b923a0bfd424/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrim.java#L21 ``` import org.apache.commons.lang.StringUtils; ``` - https://github.com/apache/hive/blob/af7059e2bdc8b18af42e0b7f7163b923a0bfd424/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrim.java#L42 ``` return StringUtils.strip(val, " "); ``` As a result, Maven CIs are broken. - https://github.com/apache/spark/actions/runs/9032639456/job/24825599546 (Maven / Java 17) - https://github.com/apache/spark/actions/runs/9033374547/job/24835284769 (Maven / Java 21) The root cause is that the existing test UDF jar `hive-test-udfs.jar` was built from old Hive (before 2.3.10) libraries which requires `commons-lang:commons-lang:2.6`. ``` HiveUDFDynamicLoadSuite: - Spark should be able to run Hive UDF using jar regardless of current thread context classloader (UDF 20:21:25.129 WARN org.apache.spark.SparkContext: The JAR file:///home/runner/work/spark/spark/sql/hive/src/test/noclasspath/hive-test-udfs.jar at spark://localhost:33327/jars/hive-test-udfs.jar has been added already. Overwriting of added jar is not supported in the current version. *** RUN ABORTED *** A needed class was not found. This could be due to an error in your runpath. 
Missing class: org/apache/commons/lang/StringUtils java.lang.NoClassDefFoundError: org/apache/commons/lang/StringUtils at org.apache.hadoop.hive.contrib.udf.example.GenericUDFTrim2.performOp(GenericUDFTrim2.java:43) at org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseTrim.evaluate(GenericUDFBaseTrim.java:75) at org.apache.hadoop.hive.ql.udf.generic.GenericUDF.initializeAndFoldConstants(GenericUDF.java:170) at org.apache.spark.sql.hive.HiveGenericUDFEvaluator.returnInspector$lzycompute(hiveUDFEvaluators.scala:118) at org.apache.spark.sql.hive.HiveGenericUDFEvaluator.returnInspector(hiveUDFEvaluators.scala:117) at org.apache.spark.sql.hive.HiveGenericUDF.dataType$lzycompute(hiveUDFs.scala:132) at org.apache.spark.sql.hive.HiveGenericUDF.dataType(hiveUDFs.scala:132) at org.apache.spark.sql.hive.HiveUDFExpressionBuilder$.makeHiveFunctionExpression(HiveSessionStateBuilder.scala:184) at org.apache.spark.sql.hive.HiveUDFExpressionBuilder$.$anonfun$makeExpression$1(HiveSessionStateBuilder.scala:164) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:185) ... Cause: java.lang.ClassNotFoundException: org.apache.commons.lang.StringUtils at java.base/java.net.URLClassLoader.findClass(URLClassLoader.java:445) at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:593) at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:526) at org.apache.hadoop.hive.contrib.udf.example.GenericUDFTrim2.performOp(GenericUDFTrim2.java:43) at org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseTrim.evaluate(GenericUDFBaseTrim.java:75) at org.apache.hadoop.hive.ql.udf.generic.GenericUDF.initializeAndFoldConstants(GenericUDF.java:170) at org.apache.spark.sql.hive.HiveGenericUDFEvaluator.returnInspector$lzycompute(hiveUDFEvaluators.scala:118) at org.apache.spark.sql.hive.HiveGenericUDFEvaluator.returnInspector(hiveUDFEvaluators.scala:117) at org.apache.spark.sql.hive.HiveGenericUDF.dataType$lzycompute(hiveUDFs.scala:132) at org.apache.spark.sql.hive.HiveGenericUDF.dataType(hiveUDFs.scala:132) ... ``` ### Does this PR introduce _any_ user-facing change? To support the existing customer UDF jars. ### How was this patch tested? Manually. ``` $ build/mvn -Dtest=none -DwildcardSuites=org.apache.spark.sql.hive.HiveUDFDynamicLoadSuite test ... HiveUDFDynamicLoadSuite: 14:21:56.034 WARN org.apache.hadoop.hive.metastore.ObjectStore: Version information not found in metastore. hive.metastore.schema.verification is not enabled so recording the schema version 2.3.0 14:21:56.035 WARN org.apache.hadoop.hive.metastore.ObjectStore: setMetaStoreSchemaVersion called but recording version is disabled: version = 2.3.0, comment = Set by MetaStore dongjoon127.0.0.1 14:21:56.041 WARN org.apache.hadoop.hive.metastore.ObjectStore: Failed to get database default, returning NoSuchObjectException - Spark should be able to run Hive UDF using jar regardless of current thread context classloader (UDF 14:21:57.576 WARN org.apache.spark.SparkContext: The JAR file:///Users/dongjoon/APACHE/spark-merge/sql/hive/src/test/noclasspath/hive-test-udfs.jar at spark://localhost:55526/jars/hive-test-udfs.jar has been added already. Overwriting of added jar is not supported in the current version. - Spark should be able to run Hive UDF using jar regardless of current thread context classloader (GENERIC_UDF 14:21:58.314 WARN org.apache.spark.SparkContext: The JAR file:///Users/dongjoon/APACHE/spark-merge/sql/hive/src/test/noclasspath/hive-test-udfs.jar at spark://localhost:55526/jars/hive-test-udfs.jar has been added already. 
Overwriting of added jar is not supported in the current version. - Spark should be able to run Hive UDF using jar regardless of current thread context classloader (GENERIC_UDAF 14:21:58.943 WARN org.apache.spark.SparkContext: The JAR file:///Users/dongjoon/APACHE/spark-merge/sql/hive/src/test/noclasspath/hive-test-udfs.jar at spark://localhost:55526/jars/hive-test-udfs.jar has been added already. Overwriting of added jar is not supported in the current version. - Spark should be able to run Hive UDF using jar regardless of current thread context classloader (UDAF 14:21:59.333 WARN org.apache.hadoop.hive.ql.session.SessionState: METASTORE_FILTER_HOOK will be ignored, since hive.security.authorization.manager is set to instance of HiveAuthorizerFactory. 14:21:59.364 WARN org.apache.hadoop.hive.conf.HiveConf: HiveConf of name hive.internal.ss.authz.settings.applied.marker does not exist 14:21:59.370 WARN org.apache.hadoop.hive.metastore.HiveMetaStore: Location: file:/Users/dongjoon/APACHE/spark-merge/sql/hive/target/tmp/warehouse-49291492-9d48-4360-a354-ace73a2c76ce/src specified for non-external table:src 14:21:59.718 WARN org.apache.hadoop.hive.metastore.ObjectStore: Failed to get database global_temp, returning NoSuchObjectException 14:21:59.770 WARN org.apache.spark.SparkContext: The JAR file:///Users/dongjoon/APACHE/spark-merge/sql/hive/src/test/noclasspath/hive-test-udfs.jar at spark://localhost:55526/jars/hive-test-udfs.jar has been added already. Overwriting of added jar is not supported in the current version. - Spark should be able to run Hive UDF using jar regardless of current thread context classloader (GENERIC_UDTF 14:22:00.403 WARN org.apache.hadoop.hive.common.FileUtils: File file:/Users/dongjoon/APACHE/spark-merge/sql/hive/target/tmp/warehouse-49291492-9d48-4360-a354-ace73a2c76ce/src does not exist; Force to delete it. 14:22:00.404 ERROR org.apache.hadoop.hive.common.FileUtils: Failed to delete file:/Users/dongjoon/APACHE/spark-merge/sql/hive/target/tmp/warehouse-49291492-9d48-4360-a354-ace73a2c76ce/src 14:22:00.441 WARN org.apache.hadoop.hive.conf.HiveConf: HiveConf of name hive.internal.ss.authz.settings.applied.marker does not exist 14:22:00.453 WARN org.apache.hadoop.hive.ql.session.SessionState: METASTORE_FILTER_HOOK will be ignored, since hive.security.authorization.manager is set to instance of HiveAuthorizerFactory. 14:22:00.537 WARN org.apache.hadoop.hive.conf.HiveConf: HiveConf of name hive.internal.ss.authz.settings.applied.marker does not exist Run completed in 8 seconds, 612 milliseconds. Total number of tests run: 5 Suites: completed 2, aborted 0 Tests: succeeded 5, failed 0, canceled 0, ignored 0, pending 0 All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? Closes #46528 from dongjoon-hyun/SPARK-48236. 
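To make the runtime dependency concrete, here is a hedged sketch of the kind of legacy-style Hive UDF that such user jars typically contain. The class name is made up, but the `org.apache.commons.lang.StringUtils` import is exactly what pulls commons-lang 2.x onto the classpath, as in the Hive 3.1.3 code quoted above.

```scala
// Sketch of a legacy-style Hive UDF; illustrative only, not shipped with this patch.
// It compiles against hive-exec 2.x/3.x and commons-lang 2.x (not commons-lang3).
import org.apache.commons.lang.StringUtils
import org.apache.hadoop.hive.ql.exec.UDF

class LegacyTrimUDF extends UDF {
  // Old-style Hive UDFs expose one or more `evaluate` methods resolved reflectively.
  def evaluate(value: String): String =
    if (value == null) null else StringUtils.strip(value, " ")
}
```

Loading a jar containing such a class (for example via `ADD JAR` plus `CREATE TEMPORARY FUNCTION`) fails with `NoClassDefFoundError: org/apache/commons/lang/StringUtils` unless commons-lang 2.6 is present, which is the failure shown in the Maven CI logs above.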
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- connector/kafka-0-10-assembly/pom.xml | 5 +++++ connector/kinesis-asl-assembly/pom.xml | 5 +++++ dev/deps/spark-deps-hadoop-3-hive-2.3 | 1 + pom.xml | 13 +++++++++++++ sql/hive/pom.xml | 4 ++++ 5 files changed, 28 insertions(+) diff --git a/connector/kafka-0-10-assembly/pom.xml b/connector/kafka-0-10-assembly/pom.xml index bd311b3a98047..b2fcbdf8eca7d 100644 --- a/connector/kafka-0-10-assembly/pom.xml +++ b/connector/kafka-0-10-assembly/pom.xml @@ -54,6 +54,11 @@ commons-codec provided + + commons-lang + commons-lang + provided + com.google.protobuf protobuf-java diff --git a/connector/kinesis-asl-assembly/pom.xml b/connector/kinesis-asl-assembly/pom.xml index 0e93526fce721..577ec21530837 100644 --- a/connector/kinesis-asl-assembly/pom.xml +++ b/connector/kinesis-asl-assembly/pom.xml @@ -54,6 +54,11 @@ jackson-databind provided + + commons-lang + commons-lang + provided + org.glassfish.jersey.core jersey-client diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 392bacd73277f..2b444dddcbe99 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -46,6 +46,7 @@ commons-compress/1.26.1//commons-compress-1.26.1.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar commons-io/2.16.1//commons-io-2.16.1.jar +commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.14.0//commons-lang3-3.14.0.jar commons-math3/3.6.1//commons-math3-3.6.1.jar commons-pool/1.5.4//commons-pool-1.5.4.jar diff --git a/pom.xml b/pom.xml index 56a34cedde516..ad6e9391b68cb 100644 --- a/pom.xml +++ b/pom.xml @@ -192,6 +192,8 @@ 1.17.0 1.26.1 2.16.1 + + 2.6 3.14.0 @@ -613,6 +615,11 @@ commons-text 1.12.0 + + commons-lang + commons-lang + ${commons-lang2.version} + commons-io commons-io @@ -2899,6 +2906,12 @@ hive-storage-api ${hive.storage.version} ${hive.storage.scope} + + + commons-lang + commons-lang + + commons-cli diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 3895d9dc5a634..56cad7f2b1df1 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -40,6 +40,10 @@ spark-core_${scala.binary.version} ${project.version} + + commons-lang + commons-lang + org.apache.spark spark-core_${scala.binary.version} From d82458f15539eef8df320345a7c2382ca4d5be8a Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Fri, 10 May 2024 16:31:47 -0700 Subject: [PATCH 57/65] [SPARK-48205][SQL][FOLLOWUP] Add missing tags for the dataSource API ### What changes were proposed in this pull request? This is a follow-up PR for https://github.com/apache/spark/pull/46487 to add missing tags for the `dataSource` API. ### Why are the changes needed? To address comments from a previous PR. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test ### Was this patch authored or co-authored using generative AI tooling? No Closes #46530 from allisonwang-db/spark-48205-followup. 
Authored-by: allisonwang-db Signed-off-by: Dongjoon Hyun --- .../src/main/scala/org/apache/spark/sql/SparkSession.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d5de74455dceb..466e4cf813185 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -233,7 +233,11 @@ class SparkSession private( /** * A collection of methods for registering user-defined data sources. + * + * @since 4.0.0 */ + @Experimental + @Unstable def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration /** From f699f556d8a09bb755e9c8558661a36fbdb42e73 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 10 May 2024 19:54:29 -0700 Subject: [PATCH 58/65] [SPARK-48237][BUILD] Clean up `dev/pr-deps` at the end of `test-dependencies.sh` script ### What changes were proposed in this pull request? The pr aims to delete the dir `dev/pr-deps` after executing `test-dependencies.sh`. ### Why are the changes needed? We'd better clean the `temporary files` generated at the end. Before: ``` sh dev/test-dependencies.sh ``` image After: ``` sh dev/test-dependencies.sh ``` image ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46531 from panbingkun/minor_test-dependencies. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- dev/test-dependencies.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 048c59f4cec9b..e645a66165a20 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -140,4 +140,8 @@ for HADOOP_HIVE_PROFILE in "${HADOOP_HIVE_PROFILES[@]}"; do fi done +if [[ -d "$FWDIR/dev/pr-deps" ]]; then + rm -rf "$FWDIR/dev/pr-deps" +fi + exit 0 From 57b207774382e3a35345518ede5cfc028885f90b Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 11 May 2024 21:41:14 +0900 Subject: [PATCH 59/65] [SPARK-48240][DOCS] Replace `Local[..]` with `"Local[...]"` in the docs ### What changes were proposed in this pull request? The pr aims to replace `Local[..]` with `"Local[...]"` in the docs ### Why are the changes needed? 1.When I recently switched from `bash` to `zsh` and executed command `./bin/spark-shell --master local[8]` on local, the following error will be printed: image 2.Some descriptions in the existing documents have been written as `--master "local[n]"`, eg: https://github.com/apache/spark/blob/f699f556d8a09bb755e9c8558661a36fbdb42e73/docs/index.md?plain=1#L49 3.The root cause is: https://blog.peiyingchi.com/2017/03/20/spark-zsh-no-matches-found-local/ image ### Does this PR introduce _any_ user-facing change? Yes, with the `zsh` becoming the mainstream of shell, avoid the confusion of spark users when submitting apps with `./bin/spark-shell --master "local[n]" ...` or `./bin/spark-sql --master "local[n]" ...`, etc ### How was this patch tested? Manually test Whether the user uses `bash` or `zsh`, the above `--master "local[n]"` can be executed successfully in the expected way. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46535 from panbingkun/SPARK-48240. 
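As a side note, the quoting is only needed because zsh tries to glob-expand the square brackets on the command line; when the master URL is set programmatically the string is passed literally. A minimal sketch (object and app names are made up):

```scala
import org.apache.spark.sql.SparkSession

object LocalMasterSketch {
  def main(args: Array[String]): Unit = {
    // No shell is involved here, so "local[8]" needs no quoting beyond being
    // an ordinary Scala string literal.
    val spark = SparkSession.builder()
      .master("local[8]")
      .appName("local-master-sketch")
      .getOrCreate()
    println(s"default parallelism: ${spark.sparkContext.defaultParallelism}")
    spark.stop()
  }
}
```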
Authored-by: panbingkun Signed-off-by: Hyukjin Kwon --- docs/configuration.md | 4 ++-- docs/quick-start.md | 6 +++--- docs/rdd-programming-guide.md | 12 ++++++------ docs/submitting-applications.md | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index c018b9f1fb7c0..7884a2af60b23 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -91,7 +91,7 @@ Then, you can supply configuration values at runtime: ```sh ./bin/spark-submit \ --name "My app" \ - --master local[4] \ + --master "local[4]" \ --conf spark.eventLog.enabled=false \ --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" \ myApp.jar @@ -3750,7 +3750,7 @@ Also, you can modify or add configurations at runtime: {% highlight bash %} ./bin/spark-submit \ --name "My app" \ - --master local[4] \ + --master "local[4]" \ --conf spark.eventLog.enabled=false \ --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" \ --conf spark.hadoop.abc.def=xyz \ diff --git a/docs/quick-start.md b/docs/quick-start.md index 366970cf66c71..5a03af98cd832 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -286,7 +286,7 @@ We can run this application using the `bin/spark-submit` script: {% highlight bash %} # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ - --master local[4] \ + --master "local[4]" \ SimpleApp.py ... Lines with a: 46, Lines with b: 23 @@ -371,7 +371,7 @@ $ sbt package # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ --class "SimpleApp" \ - --master local[4] \ + --master "local[4]" \ target/scala-{{site.SCALA_BINARY_VERSION}}/simple-project_{{site.SCALA_BINARY_VERSION}}-1.0.jar ... Lines with a: 46, Lines with b: 23 @@ -452,7 +452,7 @@ $ mvn package # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ --class "SimpleApp" \ - --master local[4] \ + --master "local[4]" \ target/simple-project-1.0.jar ... Lines with a: 46, Lines with b: 23 diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index f75bda0ffafb0..cbbce4c082060 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -214,13 +214,13 @@ can be passed to the `--repositories` argument. For example, to run `bin/pyspark` on exactly four cores, use: {% highlight bash %} -$ ./bin/pyspark --master local[4] +$ ./bin/pyspark --master "local[4]" {% endhighlight %} Or, to also add `code.py` to the search path (in order to later be able to `import code`), use: {% highlight bash %} -$ ./bin/pyspark --master local[4] --py-files code.py +$ ./bin/pyspark --master "local[4]" --py-files code.py {% endhighlight %} For a complete list of options, run `pyspark --help`. Behind the scenes, @@ -260,19 +260,19 @@ can be passed to the `--repositories` argument. For example, to run `bin/spark-s four cores, use: {% highlight bash %} -$ ./bin/spark-shell --master local[4] +$ ./bin/spark-shell --master "local[4]" {% endhighlight %} Or, to also add `code.jar` to its classpath, use: {% highlight bash %} -$ ./bin/spark-shell --master local[4] --jars code.jar +$ ./bin/spark-shell --master "local[4]" --jars code.jar {% endhighlight %} To include a dependency using Maven coordinates: {% highlight bash %} -$ ./bin/spark-shell --master local[4] --packages "org.example:example:0.1" +$ ./bin/spark-shell --master "local[4]" --packages "org.example:example:0.1" {% endhighlight %} For a complete list of options, run `spark-shell --help`. 
Behind the scenes, @@ -781,7 +781,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which may behave differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which may behave differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = "local[n]"`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
    diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index bf02ec137e200..3a99151768a12 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -91,7 +91,7 @@ run it with `--help`. Here are a few examples of common options: # Run application locally on 8 cores ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ - --master local[8] \ + --master "local[8]" \ /path/to/examples.jar \ 100 From 5b965f70c057cb478896feea2456fc59267596df Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 13 May 2024 08:26:52 +0900 Subject: [PATCH 60/65] [SPARK-48239][INFRA] Update the release docker image to follow what we use in Github Action jobs ### What changes were proposed in this pull request? We have Github Action jobs to test package building and doc generation, but the execution environment is different from what we use for the release process. This PR updates the release docker image to follow what we use in Github Action: https://github.com/apache/spark/blob/master/dev/infra/Dockerfile Note: it's not exactly the same, as I have to do some modification to make it usable for the release process. In the future we should have a better way to unify these two docker files. ### Why are the changes needed? to make us be able to release ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manually ### Was this patch authored or co-authored using generative AI tooling? no Closes #46534 from cloud-fan/re. Authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- dev/create-release/release-build.sh | 3 + dev/create-release/spark-rm/Dockerfile | 170 ++++++++++++++++--------- 2 files changed, 112 insertions(+), 61 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index b720a8fc93861..0fb16aafcbaad 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -80,6 +80,9 @@ done export LC_ALL=C.UTF-8 export LANG=C.UTF-8 +export PYSPARK_PYTHON=/usr/local/bin/python +export PYSPARK_DRIVER_PYTHON=/usr/local/bin/python + # Commit ref to checkout when building GIT_REF=${GIT_REF:-master} diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index 13f4112ca03da..adaa4df3f5791 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -15,74 +15,122 @@ # limitations under the License. # -# Image for building Spark releases. Based on Ubuntu 20.04. -# -# Includes: -# * Java 17 -# * Ivy -# * Python (3.8.5) -# * R-base/R-base-dev (4.0.3) -# * Ruby (2.7.0) -# -# You can test it as below: -# cd dev/create-release/spark-rm -# docker build -t spark-rm --build-arg UID=$UID . +# Image for building Spark releases. Based on Ubuntu 22.04. +FROM ubuntu:jammy-20240227 -FROM ubuntu:20.04 +ENV FULL_REFRESH_DATE 20240318 -# For apt to be noninteractive ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true -# These arguments are just for reuse and not really meant to be customized. 
-ARG APT_INSTALL="apt-get install --no-install-recommends -y" +RUN apt-get update && apt-get install -y \ + build-essential \ + ca-certificates \ + curl \ + gfortran \ + git \ + subversion \ + gnupg \ + libcurl4-openssl-dev \ + libfontconfig1-dev \ + libfreetype6-dev \ + libfribidi-dev \ + libgit2-dev \ + libharfbuzz-dev \ + libjpeg-dev \ + liblapack-dev \ + libopenblas-dev \ + libpng-dev \ + libpython3-dev \ + libssl-dev \ + libtiff5-dev \ + libxml2-dev \ + nodejs \ + npm \ + openjdk-17-jdk-headless \ + pandoc \ + pkg-config \ + python3.10 \ + python3-psutil \ + texlive-latex-base \ + texlive \ + texlive-fonts-extra \ + texinfo \ + texlive-latex-extra \ + qpdf \ + r-base \ + ruby \ + ruby-dev \ + software-properties-common \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* -ARG PIP_PKGS="sphinx==4.5.0 mkdocs==1.1.2 numpy==1.20.3 pydata_sphinx_theme==0.13.3 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==3.1.2 twine==3.4.1 sphinx-plotly-directive==0.1.3 sphinx-copybutton==0.5.2 pandas==2.0.3 pyarrow==10.0.1 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.62.0 protobuf==4.21.6 grpcio-status==1.62.0 googleapis-common-protos==1.56.4" -ARG GEM_PKGS="bundler:2.4.22" -# Install extra needed repos and refresh. -# - CRAN repo -# - Ruby repo (for doc generation) -# -# This is all in a single "RUN" command so that if anything changes, "apt update" is run to fetch -# the most current package versions (instead of potentially using old versions cached by docker). -RUN apt-get clean && apt-get update && $APT_INSTALL gnupg ca-certificates && \ - echo 'deb https://cloud.r-project.org/bin/linux/ubuntu focal-cran40/' >> /etc/apt/sources.list && \ - gpg --keyserver hkps://keyserver.ubuntu.com --recv-key E298A3A825C0D65DFD57CBB651716619E084DAB9 && \ - gpg -a --export E084DAB9 | apt-key add - && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* && \ - apt-get clean && \ - apt-get update && \ - $APT_INSTALL software-properties-common && \ - apt-get update && \ - # Install openjdk 17. - $APT_INSTALL openjdk-17-jdk && \ - update-alternatives --set java $(ls /usr/lib/jvm/java-17-openjdk-*/bin/java) && \ - # Install build / source control tools - $APT_INSTALL curl wget git maven ivy subversion make gcc lsof libffi-dev \ - pandoc pandoc-citeproc libssl-dev libcurl4-openssl-dev libxml2-dev && \ - curl -sL https://deb.nodesource.com/setup_12.x | bash && \ - $APT_INSTALL nodejs && \ - # Install needed python packages. Use pip for installing packages (for consistency). - $APT_INSTALL python-is-python3 python3-pip python3-setuptools && \ - # qpdf is required for CRAN checks to pass. - $APT_INSTALL qpdf jq && \ - pip3 install $PIP_PKGS && \ - # Install R packages and dependencies used when building. - # R depends on pandoc*, libssl (which are installed above). 
- # Note that PySpark doc generation also needs pandoc due to nbsphinx - $APT_INSTALL r-base r-base-dev && \ - $APT_INSTALL libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev && \ - $APT_INSTALL texlive-latex-base texlive texlive-fonts-extra texinfo qpdf texlive-latex-extra && \ - $APT_INSTALL libfontconfig1-dev libharfbuzz-dev libfribidi-dev libfreetype6-dev libpng-dev libtiff5-dev libjpeg-dev && \ - Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'markdown', 'roxygen2', 'e1071', 'survival'), repos='https://cloud.r-project.org/')" && \ - Rscript -e "devtools::install_github('jimhester/lintr')" && \ - Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" && \ - Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" && \ - # Install tools needed to build the documentation. - $APT_INSTALL ruby2.7 ruby2.7-dev && \ - gem install --no-document $GEM_PKGS +RUN echo 'deb https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/' >> /etc/apt/sources.list +RUN gpg --keyserver hkps://keyserver.ubuntu.com --recv-key E298A3A825C0D65DFD57CBB651716619E084DAB9 +RUN gpg -a --export E084DAB9 | apt-key add - +RUN add-apt-repository 'deb https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/' + +# See more in SPARK-39959, roxygen2 < 7.2.1 +RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', \ + 'rmarkdown', 'testthat', 'devtools', 'e1071', 'survival', 'arrow', \ + 'ggplot2', 'mvtnorm', 'statmod', 'xml2'), repos='https://cloud.r-project.org/')" && \ + Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='https://cloud.r-project.org')" && \ + Rscript -e "devtools::install_version('lintr', version='2.0.1', repos='https://cloud.r-project.org')" && \ + Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" && \ + Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" + +# See more in SPARK-39735 +ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library" + + +RUN add-apt-repository ppa:pypy/ppa +RUN mkdir -p /usr/local/pypy/pypy3.9 && \ + curl -sqL https://downloads.python.org/pypy/pypy3.9-v7.3.16-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.9 --strip-components=1 && \ + ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.8 && \ + ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 +RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml + + +ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +# Python deps for Spark Connect +ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4" + +# Install Python 3.10 packages +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 +RUN python3.10 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this +RUN python3.10 -m pip install --ignore-installed 'six==1.16.0' # Avoid `python3-six` installation +RUN python3.10 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \ + python3.10 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \ + python3.10 -m pip install deepspeed torcheval && \ + 
python3.10 -m pip cache purge + +# Install Python 3.9 +RUN add-apt-repository ppa:deadsnakes/ppa +RUN apt-get update && apt-get install -y \ + python3.9 python3.9-distutils \ + && rm -rf /var/lib/apt/lists/* +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9 +RUN python3.9 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this +RUN python3.9 -m pip install --force $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \ + python3.9 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \ + python3.9 -m pip install torcheval && \ + python3.9 -m pip cache purge + +# Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 +# See 'ipython_genutils' in SPARK-38517 +# See 'docutils<0.18.0' in SPARK-39421 +RUN python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ +ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ +'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ +'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ +'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' +RUN python3.9 -m pip list + +RUN gem install --no-document "bundler:2.4.22" +RUN ln -s "$(which python3.9)" "/usr/local/bin/python" WORKDIR /opt/spark-rm/output From b5584221cfc2d3cb052c082d8a94b4a00ccf4ed4 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Mon, 13 May 2024 08:45:11 +0900 Subject: [PATCH 61/65] [SPARK-48245][SQL] Fix typo in BadRecordException class doc ### What changes were proposed in this pull request? Fix typo in `BadRecordException` class doc ### Why are the changes needed? To avoid annoyance ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? No Closes #46542 from vladimirg-db/vladimirg-db/fix-typo-in-bad-record-exception-doc. Authored-by: Vladimir Golubev Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/sql/catalyst/util/BadRecordException.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala index 654b0b8c73e51..4fa6a2275e743 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -68,7 +68,7 @@ case class PartialResultArrayException( /** * Exception thrown when the underlying parser met a bad record and can't parse it. - * The stacktrace is not collected for better preformance, and thus, this exception should + * The stacktrace is not collected for better performance, and thus, this exception should * not be used in a user-facing context. 
* @param record a function to return the record that cause the parser to fail * @param partialResults a function that returns an row array, which is the partial results of From cae2248bc13d8bde7c48a1d7479df68bcd31fbf1 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 13 May 2024 11:09:44 +0800 Subject: [PATCH 62/65] [MINOR][PYTHON][TESTS] Move test `test_named_arguments_negative` to `test_arrow_python_udf` ### What changes were proposed in this pull request? Move test `test_named_arguments_negative` to `test_arrow_python_udf` ### Why are the changes needed? it seems was added in a wrong place, it only runs in Spark Connect, not Spark Classic. After this PR, it will also be run in Spark Classic ### Does this PR introduce _any_ user-facing change? no, test only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #46544 from zhengruifeng/move_test_named_arguments_negative. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../connect/test_parity_arrow_python_udf.py | 26 ------------------- .../sql/tests/test_arrow_python_udf.py | 24 ++++++++++++++++- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py index fa329b598d98b..732008eb05a35 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py @@ -15,10 +15,6 @@ # limitations under the License. # -import unittest - -from pyspark.errors import AnalysisException, PythonException -from pyspark.sql.functions import udf from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin @@ -36,28 +32,6 @@ def tearDownClass(cls): finally: super(ArrowPythonUDFParityTests, cls).tearDownClass() - def test_named_arguments_negative(self): - @udf("int") - def test_udf(a, b): - return a + b - - self.spark.udf.register("test_udf", test_udf) - - with self.assertRaisesRegex( - AnalysisException, - "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", - ): - self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show() - - with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): - self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show() - - with self.assertRaises(PythonException): - self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show() - - with self.assertRaises(PythonException): - self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show() - if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 23f302ec3c8d3..5a66d61cb66a2 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -17,7 +17,7 @@ import unittest -from pyspark.errors import PythonException, PySparkNotImplementedError +from pyspark.errors import AnalysisException, PythonException, PySparkNotImplementedError from pyspark.sql import Row from pyspark.sql.functions import udf from pyspark.sql.tests.test_udf import BaseUDFTestsMixin @@ -197,6 +197,28 @@ def test_warn_no_args(self): " without arguments.", ) + def test_named_arguments_negative(self): + @udf("int") + def test_udf(a, b): + return a + b + + self.spark.udf.register("test_udf", 
test_udf) + + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show() + + with self.assertRaises(PythonException): + self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show() + + with self.assertRaises(PythonException): + self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show() + class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase): @classmethod From acc37531deb9a01555c7ce691aab2629f42c25b0 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 13 May 2024 13:43:27 +0800 Subject: [PATCH 63/65] [SPARK-47993][PYTHON][FOLLOW-UP] Update migration guide about Python 3.8 dropped ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/46228 that updates migration guide about Python 3.8 being dropped. ### Why are the changes needed? To guide end users about the migration to Spark 4.0. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the documentation. ### How was this patch tested? CI in this PR. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46545 from HyukjinKwon/SPARK-47993-followup. Authored-by: Hyukjin Kwon Signed-off-by: yangjie01 --- python/docs/source/migration_guide/pyspark_upgrade.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index 26fc634307879..0f252519e7daf 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -22,6 +22,7 @@ Upgrading PySpark Upgrading from PySpark 3.5 to 4.0 --------------------------------- +* In Spark 4.0, Python 3.8 support was dropped in PySpark. * In Spark 4.0, the minimum supported version for Pandas has been raised from 1.0.5 to 2.0.0 in PySpark. * In Spark 4.0, the minimum supported version for Numpy has been raised from 1.15 to 1.21 in PySpark. * In Spark 4.0, the minimum supported version for PyArrow has been raised from 4.0.0 to 10.0.0 in PySpark. From 13b0d1aab36740293814ce54e38cb4d86f8b762d Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 13 May 2024 17:14:54 +0900 Subject: [PATCH 64/65] [SPARK-48250][PYTHON][CONNECT][TESTS] Enable array inference tests at test_parity_types.py ### What changes were proposed in this pull request? This PR proposes to enable some array inference tests at test_parity_types.py ### Why are the changes needed? For better test coverage for Spark Connect. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? CI in this PR should verify them. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46550 from HyukjinKwon/SPARK-48250. 
### Does this PR introduce _any_ user-facing change?
No, test-only.

### How was this patch tested?
CI in this PR should verify them.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46550 from HyukjinKwon/SPARK-48250.

Lead-authored-by: Hyukjin Kwon
Co-authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 .../pyspark/sql/tests/connect/test_parity_types.py |  8 ++------
 python/pyspark/sql/tests/test_types.py             | 14 +++++++-------
 2 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 82a677574b455..55acb4b1a381b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -39,12 +39,8 @@ def test_create_dataframe_schema_mismatch(self):
         super().test_create_dataframe_schema_mismatch()
 
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
-    def test_infer_array_element_type_empty(self):
-        super().test_infer_array_element_type_empty()
-
-    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
-    def test_infer_array_element_type_with_struct(self):
-        super().test_infer_array_element_type_with_struct()
+    def test_infer_array_element_type_empty_rdd(self):
+        super().test_infer_array_element_type_empty_rdd()
 
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_infer_array_merge_element_types_with_rdd(self):
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 40eded6a4433c..9c931d861c18d 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -366,7 +366,7 @@ def test_infer_array_merge_element_types_with_rdd(self):
         df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(f1=[1, None], f2=[None, 2]), df.first())
 
-    def test_infer_array_element_type_empty(self):
+    def test_infer_array_element_type_empty_rdd(self):
         # SPARK-39168: Test inferring array element type from all rows
         ArrayRow = Row("f1")
 
@@ -379,6 +379,12 @@ def test_infer_array_element_type_empty(self):
         self.assertEqual(Row(f1=[None]), rows[1])
         self.assertEqual(Row(f1=[1]), rows[2])
 
+    def test_infer_array_element_type_empty(self):
+        # SPARK-39168: Test inferring array element type from all rows
+        ArrayRow = Row("f1")
+
+        data = [ArrayRow([]), ArrayRow([None]), ArrayRow([1])]
+        df = self.spark.createDataFrame(data)
         rows = df.collect()
         self.assertEqual(Row(f1=[]), rows[0])
@@ -392,12 +398,6 @@ def test_infer_array_element_type_with_struct(self):
         with self.sql_conf({"spark.sql.pyspark.inferNestedDictAsStruct.enabled": True}):
             data = [NestedRow([{"payment": 200.5}, {"name": "A"}])]
 
-            nestedRdd = self.sc.parallelize(data)
-            df = self.spark.createDataFrame(nestedRdd)
-            self.assertEqual(
-                Row(f1=[Row(payment=200.5, name=None), Row(payment=None, name="A")]), df.first()
-            )
-
             df = self.spark.createDataFrame(data)
             self.assertEqual(
                 Row(f1=[Row(payment=200.5, name=None), Row(payment=None, name="A")]), df.first()

From b2140d0f25d81e64a968df83c5da5089051acaac Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Mon, 13 May 2024 17:15:28 +0900
Subject: [PATCH 65/65] [SPARK-48248][PYTHON] Fix nested array to respect legacy conf of inferArrayTypeFromFirstElement

### What changes were proposed in this pull request?
This PR fixes a bug where `spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled` was not respected for nested arrays, introduced by https://github.com/apache/spark/pull/36545.

### Why are the changes needed?
To have a way to restore the original behaviour.
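For illustration, a minimal sketch of the behaviour the legacy conf restores once nested arrays respect it (assuming a plain local `SparkSession`; the input literal mirrors the new unit test below):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.conf.set("spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", True)

# The inner array mixes an int and a string. With the legacy conf enabled, the
# element type is inferred from the first element only (long), so "a" cannot be
# converted and surfaces as None.
df = spark.createDataFrame([[[[1, "a"]]]])
print(df.first()[0])  # [[1, None]]
```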
### Does this PR introduce _any_ user-facing change?
Yes, it fixes the regression when `spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled` is set to `True`.

### How was this patch tested?
Unittest added.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46548 from HyukjinKwon/SPARK-48248.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/sql/tests/test_types.py |  7 +++++++
 python/pyspark/sql/types.py            | 18 ++++++++++++++++--
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 9c931d861c18d..bd99804ec5655 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1621,6 +1621,13 @@ def test_collated_string(self):
             StringType("UTF8_BINARY_LCASE"),
         )
 
+    def test_infer_array_element_type_with_struct(self):
+        # SPARK-48248: Nested array to respect legacy conf of inferArrayTypeFromFirstElement
+        with self.sql_conf(
+            {"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled": True}
+        ):
+            self.assertEqual([[1, None]], self.spark.createDataFrame([[[[1, "a"]]]]).first()[0])
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 41be12620fd56..fbd4987713e26 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1951,13 +1951,27 @@ def _infer_type(
         if len(obj) > 0:
             if infer_array_from_first_element:
                 return ArrayType(
-                    _infer_type(obj[0], infer_dict_as_struct, prefer_timestamp_ntz), True
+                    _infer_type(
+                        obj[0],
+                        infer_dict_as_struct,
+                        infer_array_from_first_element,
+                        prefer_timestamp_ntz,
+                    ),
+                    True,
                 )
             else:
                 return ArrayType(
                     reduce(
                         _merge_type,
-                        (_infer_type(v, infer_dict_as_struct, prefer_timestamp_ntz) for v in obj),
+                        (
+                            _infer_type(
+                                v,
+                                infer_dict_as_struct,
+                                infer_array_from_first_element,
+                                prefer_timestamp_ntz,
+                            )
+                            for v in obj
+                        ),
                     ),
                     True,
                 )