From ed03173f33aad0cf8b18096a4fb2470059410751 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 31 Aug 2023 12:36:35 +0800 Subject: [PATCH 01/35] [SPARK-45024][PYTHON][CONNECT] Filter out some configurations in Session Creation ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/42694 filtered out static configurations in local mode This filter out some configurations in both local mode and non-local mode, since the configurations are actually set after session creation, so configurations like `spark.remote` always take no effect. ### Why are the changes needed? avoid unnecessary RPCs and warnings ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? no Closes #42741 from zhengruifeng/filter_out_some_config. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/connect/session.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index c326f94d80cef..934386ce95475 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -180,10 +180,16 @@ def enableHiveSupport(self) -> "SparkSession.Builder": def _apply_options(self, session: "SparkSession") -> None: with self._lock: for k, v in self._options.items(): - try: - session.conf.set(k, v) - except Exception as e: - warnings.warn(str(e)) + # the options are applied after session creation, + # so following options always take no effect + if k not in [ + "spark.remote", + "spark.master", + ]: + try: + session.conf.set(k, v) + except Exception as e: + warnings.warn(str(e)) def create(self) -> "SparkSession": has_channel_builder = self._channel_builder is not None From 9a023c479c6a91a602f96ccabba398223c04b3d1 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 31 Aug 2023 15:21:19 +0900 Subject: [PATCH 02/35] [SPARK-45014][CONNECT] Clean up fileserver when cleaning up files, jars and archives in SparkContext ### What changes were proposed in this pull request? This PR proposes to clean up the files, jars and archives added via Spark Connect sessions. ### Why are the changes needed? In [SPARK-44348](https://issues.apache.org/jira/browse/SPARK-44348), we clean up Spark Context's added files but we don't clean up the ones in fileserver. ### Does this PR introduce _any_ user-facing change? Yes, it will avoid slowly growing memory within the file server. ### How was this patch tested? Manually tested. Also existing tests should not be broken. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42731 from HyukjinKwon/SPARK-45014. 
Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../artifact/SparkConnectArtifactManager.scala | 10 ++++++---- .../main/scala/org/apache/spark/rpc/RpcEnv.scala | 13 +++++++++++++ .../apache/spark/rpc/netty/NettyStreamManager.scala | 4 ++++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index a2df11eeb5832..fee99532bd55f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -208,12 +208,14 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging s"sessionId: ${sessionHolder.sessionId}") // Clean up added files - sessionHolder.session.sparkContext.addedFiles.remove(state.uuid) - sessionHolder.session.sparkContext.addedArchives.remove(state.uuid) - sessionHolder.session.sparkContext.addedJars.remove(state.uuid) + val fileserver = SparkEnv.get.rpcEnv.fileServer + val sparkContext = sessionHolder.session.sparkContext + sparkContext.addedFiles.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile)) + sparkContext.addedArchives.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile)) + sparkContext.addedJars.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeJar)) // Clean up cached relations - val blockManager = sessionHolder.session.sparkContext.env.blockManager + val blockManager = sparkContext.env.blockManager blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId) // Clean up artifacts folder diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 2fce2889c0977..2575cffdeb3b5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -206,6 +206,19 @@ private[spark] trait RpcEnvFileServer { fixedBaseUri } + /** + * Removes a file from this RpcEnv. + * + * @param key Local file to remove. + */ + def removeFile(key: String): Unit + + /** + * Removes a jar to from this RpcEnv. + * + * @param key Local jar to remove. 
+ */ + def removeJar(key: String): Unit } private[spark] case class RpcEnvConfig( diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index 57243133aba92..9ac14f3483683 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -43,6 +43,10 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) private val jars = new ConcurrentHashMap[String, File]() private val dirs = new ConcurrentHashMap[String, File]() + override def removeFile(key: String): Unit = files.remove(key) + + override def removeJar(key: String): Unit = jars.remove(key) + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { throw new UnsupportedOperationException() } From 9d84369ade670737a4ccda166e452e5208eb8253 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 31 Aug 2023 15:27:41 +0800 Subject: [PATCH 03/35] [SPARK-45018][PYTHON][CONNECT] Add CalendarIntervalType to Python Client ### What changes were proposed in this pull request? Add CalendarIntervalType to Python Client ### Why are the changes needed? for feature parity ### Does this PR introduce _any_ user-facing change? yes before this PR: ``` In [1]: from pyspark.sql import functions as sf In [2]: spark.range(1).select(sf.make_interval(sf.lit(1))).schema --------------------------------------------------------------------------- Exception Traceback (most recent call last) Cell In[2], line 1 ----> 1 spark.range(1).select(sf.make_interval(sf.lit(1))).schema File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1687, in DataFrame.schema(self) 1685 if self._session is None: 1686 raise Exception("Cannot analyze without SparkSession.") -> 1687 return self._session.client.schema(query) 1688 else: 1689 raise Exception("Empty plan.") ... Exception: Unsupported data type calendar_interval ``` after this PR: ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 4.0.0.dev0 /_/ Using Python version 3.10.11 (main, May 17 2023 14:30:36) Client connected to the Spark Connect server at localhost SparkSession available as 'spark'. In [1]: from pyspark.sql import functions as sf In [2]: spark.range(1).select(sf.make_interval(sf.lit(1))).schema Out[2]: StructType([StructField('make_interval(1, 0, 0, 0, 0, 0, 0)', CalendarIntervalType(), True)]) ``` ### How was this patch tested? added UT ### Was this patch authored or co-authored using generative AI tooling? NO Closes #42743 from zhengruifeng/py_connect_cal_interval. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/connect/types.py | 5 +++++ python/pyspark/sql/tests/connect/test_parity_types.py | 2 +- python/pyspark/sql/tests/test_types.py | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py index 0db2833d2c1aa..cd2311e614eec 100644 --- a/python/pyspark/sql/connect/types.py +++ b/python/pyspark/sql/connect/types.py @@ -33,6 +33,7 @@ TimestampNTZType, DayTimeIntervalType, YearMonthIntervalType, + CalendarIntervalType, MapType, StringType, CharType, @@ -169,6 +170,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType: elif isinstance(data_type, YearMonthIntervalType): ret.year_month_interval.start_field = data_type.startField ret.year_month_interval.end_field = data_type.endField + elif isinstance(data_type, CalendarIntervalType): + ret.calendar_interval.CopyFrom(pb2.DataType.CalendarInterval()) elif isinstance(data_type, StructType): struct = pb2.DataType.Struct() for field in data_type.fields: @@ -265,6 +268,8 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType: else None ) return YearMonthIntervalType(startField=start, endField=end) + elif schema.HasField("calendar_interval"): + return CalendarIntervalType() elif schema.HasField("array"): return ArrayType( proto_schema_to_pyspark_data_type(schema.array.element_type), diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py index 533506c7d2743..44171fd61a35b 100644 --- a/python/pyspark/sql/tests/connect/test_parity_types.py +++ b/python/pyspark/sql/tests/connect/test_parity_types.py @@ -86,7 +86,7 @@ def test_rdd_with_udt(self): def test_udt(self): super().test_udt() - @unittest.skip("SPARK-45018: should support CalendarIntervalType") + @unittest.skip("SPARK-45026: spark.sql should support datatypes not compatible with arrow") def test_calendar_interval_type(self): super().test_calendar_interval_type() diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index d45c4d7e808de..fb752b93a3308 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -1284,6 +1284,10 @@ def test_calendar_interval_type(self): schema1 = self.spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)").schema self.assertEqual(schema1.fields[0].dataType, CalendarIntervalType()) + def test_calendar_interval_type_with_sf(self): + schema1 = self.spark.range(1).select(F.make_interval(F.lit(1))).schema + self.assertEqual(schema1.fields[0].dataType, CalendarIntervalType()) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 From 723a0aa30f9a901140d0f97d580d39db56b0729f Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 31 Aug 2023 19:42:41 +0800 Subject: [PATCH 04/35] Revert "[SPARK-43646][CONNECT][TESTS] Make both SBT and Maven use `spark-proto` uber jar to test the `connect` module" ### What changes were proposed in this pull request? This reverts commit df63adf734370f5c2d71a348f9d36658718b302c. ### Why are the changes needed? As [reported](https://github.com/apache/spark/pull/42236#issuecomment-1700493815) by MaxGekk , the solution for https://github.com/apache/spark/pull/42236 is not perfect, and it breaks the usability of importing Spark as a Maven project into idea. 
On the other hand, if `mvn clean test` is executed, test failures will also occur like ``` [ERROR] [Error] /tmp/spark-3.5.0/connector/connect/server/target/generated-test-sources/protobuf/java/org/apache/spark/sql/protobuf/protos/TestProto.java:9:46: error: package org.sparkproject.spark_protobuf.protobuf does not exist ``` Therefore, this pr will revert the change of SPARK-43646, and `from_protobuf messageClassName` and `from_protobuf messageClassName options` in `PlanGenerationTestSuite` will be ignored in a follow-up. At present, it is difficult to make the maven testing of the `spark-protobuf` function in the `connect` module as good as possible. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #42746 from LuciferYang/Revert-SPARK-43646. Authored-by: yangjie01 Signed-off-by: yangjie01 --- .../spark/sql/PlanGenerationTestSuite.scala | 5 +- .../from_protobuf_messageClassName.explain | 2 +- ..._protobuf_messageClassName_options.explain | 2 +- .../from_protobuf_messageClassName.json | 2 +- .../from_protobuf_messageClassName.proto.bin | Bin 131 -> 125 bytes ...rom_protobuf_messageClassName_options.json | 2 +- ...rotobuf_messageClassName_options.proto.bin | Bin 182 -> 174 bytes connector/connect/server/pom.xml | 88 ------------------ .../server/src/test/protobuf/test.proto | 27 ------ project/SparkBuild.scala | 55 +---------- 10 files changed, 8 insertions(+), 175 deletions(-) delete mode 100644 connector/connect/server/src/test/protobuf/test.proto diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index df416ef93d83d..ccd68f75bdab1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -3236,15 +3236,14 @@ class PlanGenerationTestSuite "connect/common/src/test/resources/protobuf-tests/common.desc" test("from_protobuf messageClassName") { - binary.select( - pbFn.from_protobuf(fn.col("bytes"), "org.apache.spark.sql.protobuf.protos.TestProtoObj")) + binary.select(pbFn.from_protobuf(fn.col("bytes"), classOf[StorageLevel].getName)) } test("from_protobuf messageClassName options") { binary.select( pbFn.from_protobuf( fn.col("bytes"), - "org.apache.spark.sql.protobuf.protos.TestProtoObj", + classOf[StorageLevel].getName, Map("recursive.fields.max.depth" -> "2").asJava)) } diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain index 6f48cb090cde5..e7a1867fe9072 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain @@ -1,2 +1,2 @@ -Project [from_protobuf(bytes#0, org.apache.spark.sql.protobuf.protos.TestProtoObj, None) AS from_protobuf(bytes)#0] +Project [from_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None) AS from_protobuf(bytes)#0] +- LocalRelation , [id#0L, bytes#0] diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain index ba87e4774f1af..c02d829fcac1d 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain @@ -1,2 +1,2 @@ -Project [from_protobuf(bytes#0, org.apache.spark.sql.protobuf.protos.TestProtoObj, None, (recursive.fields.max.depth,2)) AS from_protobuf(bytes)#0] +Project [from_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None, (recursive.fields.max.depth,2)) AS from_protobuf(bytes)#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json index 6c5891e701654..dc23ac2a117b4 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json @@ -20,7 +20,7 @@ } }, { "literal": { - "string": "org.apache.spark.sql.protobuf.protos.TestProtoObj" + "string": "org.apache.spark.connect.proto.StorageLevel" } }] } diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin index 9d7aeaf308969f3069e0f06ffcb0bfb21f2be445..cc46234b7476cfc4d3623315a0016f4a8d2f1b16 100644 GIT binary patch delta 92 zcmZo>tYzb35@3`npU9@^=PTvS#hX@?pBrCLlwXpcRGKElDa6jjnp9bmS}df`rJJRl sUzDzwSdf^Uk*Zf*kXV$hmzyHi%Eb{YWYMqEx%f+YA)WiqWs+Wf};GA{G`$}Ax - - - kr.motd.maven - os-maven-plugin - 1.6.2 - - target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes @@ -410,87 +403,6 @@ - - - org.apache.maven.plugins - maven-antrun-plugin - - - process-test-sources - - - - - - - run - - - - - - - default-protoc - - true - - - - - - org.xolstice.maven.plugins - protobuf-maven-plugin - 0.6.1 - - com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} - src/test/protobuf - - - - - test-compile - - - - - - - - - user-defined-protoc - - ${env.SPARK_PROTOC_EXEC_PATH} - - - - - org.xolstice.maven.plugins - protobuf-maven-plugin - 0.6.1 - - ${spark.protoc.executable.path} - src/test/protobuf - - - - - test-compile - - - - - - - - diff --git a/connector/connect/server/src/test/protobuf/test.proto b/connector/connect/server/src/test/protobuf/test.proto deleted file mode 100644 index 844f89ba81f47..0000000000000 --- a/connector/connect/server/src/test/protobuf/test.proto +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; -package org.apache.spark.sql.protobuf.protos; - -option java_multiple_files = true; -option java_outer_classname = "TestProto"; - -message TestProtoObj { - int64 v1 = 1; - int32 v2 = 2; -} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 978165ba0dabc..2f437eeb75cc1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -17,7 +17,7 @@ import java.io._ import java.nio.charset.StandardCharsets.UTF_8 -import java.nio.file.{Files, StandardOpenOption} +import java.nio.file.Files import java.util.Locale import scala.io.Source @@ -765,13 +765,7 @@ object SparkConnectCommon { object SparkConnect { import BuildCommons.protoVersion - val rewriteJavaFile = TaskKey[Unit]("rewriteJavaFile", - "Rewrite the generated Java PB files.") - val genPBAndRewriteJavaFile = TaskKey[Unit]("genPBAndRewriteJavaFile", - "Generate Java PB files and overwrite their contents.") - lazy val settings = Seq( - PB.protocVersion := BuildCommons.protoVersion, // For some reason the resolution from the imported Maven build does not work for some // of these dependendencies that we need to shade later on. libraryDependencies ++= { @@ -802,42 +796,6 @@ object SparkConnect { ) }, - // SPARK-43646: The following 3 statements are used to make the connect module use the - // Spark-proto assembly jar when compiling and testing using SBT, which can be keep same - // behavior of Maven testing. - (Test / unmanagedJars) += (LocalProject("protobuf") / assembly).value, - (Test / fullClasspath) := - (Test / fullClasspath).value.filterNot { f => f.toString.contains("spark-protobuf") }, - (Test / fullClasspath) += (LocalProject("protobuf") / assembly).value, - - (Test / PB.protoSources) += (Test / sourceDirectory).value / "resources" / "protobuf", - - (Test / PB.targets) := Seq( - PB.gens.java -> target.value / "generated-test-sources", - ), - - // SPARK-43646: Create a custom task to replace all `com.google.protobuf.` with - // `org.sparkproject.spark_protobuf.protobuf.` in the generated Java PB files. - // This is to generate Java files that can be used to test spark-protobuf functions - // in `ProtoToParsedPlanTestSuite`. - rewriteJavaFile := { - val protobufDir = target.value / "generated-test-sources"/"org"/"apache"/"spark"/"sql"/"protobuf"/"protos" - protobufDir.listFiles().foreach { f => - if (f.getName.endsWith(".java")) { - val contents = Files.readAllLines(f.toPath, UTF_8) - val replaced = contents.asScala.map { line => - line.replaceAll("com.google.protobuf.", "org.sparkproject.spark_protobuf.protobuf.") - } - Files.write(f.toPath, replaced.asJava, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE) - } - } - }, - // SPARK-43646: `genPBAndRewriteJavaFile` is used to specify the execution order of `PB.generate` - // and `rewriteJavaFile`, and makes `Test / compile` dependent on `genPBAndRewriteJavaFile` - // being executed first. 
- genPBAndRewriteJavaFile := Def.sequential(Test / PB.generate, rewriteJavaFile).value, - (Test / compile) := (Test / compile).dependsOn(genPBAndRewriteJavaFile).value, - (assembly / test) := { }, (assembly / logLevel) := Level.Info, @@ -883,16 +841,7 @@ object SparkConnect { case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard case _ => MergeStrategy.first } - ) ++ { - val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path") - if (sparkProtocExecPath.isDefined) { - Seq( - PB.protocExecutable := file(sparkProtocExecPath.get) - ) - } else { - Seq.empty - } - } + ) } object SparkConnectClient { From e72ce91250a9a2c40fd5ed55a50dbc46e4e7e46d Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Thu, 31 Aug 2023 22:50:21 +0300 Subject: [PATCH 05/35] [SPARK-44987][SQL] Assign a name to the error class `_LEGACY_ERROR_TEMP_1100` ### What changes were proposed in this pull request? In the PR, I propose to assign the name `NON_FOLDABLE_ARGUMENT` to the legacy error class `_LEGACY_ERROR_TEMP_1100`, and improve the error message format: make it less restrictive. ### Why are the changes needed? 1. To don't confuse users by slightly restrictive error message about literals. 2. To assign proper name as a part of activity in SPARK-37935 ### Does this PR introduce _any_ user-facing change? No. Only if user's code depends on error class name and message parameters. ### How was this patch tested? By running the modified and affected tests: ``` $ build/sbt "test:testOnly *.StringFunctionsSuite" $ PYSPARK_PYTHON=python3 build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite" $ build/sbt "core/testOnly *SparkThrowableSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42737 from MaxGekk/assign-name-_LEGACY_ERROR_TEMP_1100. Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../main/resources/error/error-classes.json | 11 +++--- docs/sql-error-conditions.md | 6 ++++ .../expressions/datetimeExpressions.scala | 2 +- .../expressions/mathExpressions.scala | 4 +-- .../expressions/numberFormatExpressions.scala | 2 +- .../sql/errors/QueryCompilationErrors.scala | 14 ++++---- .../ceil-floor-with-scale-param.sql.out | 36 ++++++++++--------- .../analyzer-results/extract.sql.out | 18 +++++----- .../ceil-floor-with-scale-param.sql.out | 36 ++++++++++--------- .../sql-tests/results/extract.sql.out | 18 +++++----- .../spark/sql/StringFunctionsSuite.scala | 8 ++--- 11 files changed, 88 insertions(+), 67 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 3b537cc3d9fc6..af78dd2f9f801 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -2215,6 +2215,12 @@ ], "sqlState" : "42607" }, + "NON_FOLDABLE_ARGUMENT" : { + "message" : [ + "The function requires the parameter to be a foldable expression of the type , but the actual argument is a non-foldable." + ], + "sqlState" : "22024" + }, "NON_LAST_MATCHED_CLAUSE_OMIT_CONDITION" : { "message" : [ "When there are more than one MATCHED clauses in a MERGE statement, only the last MATCHED clause can omit the condition." @@ -4029,11 +4035,6 @@ "() doesn't support the mode. Acceptable modes are and ." ] }, - "_LEGACY_ERROR_TEMP_1100" : { - "message" : [ - "The '' parameter of function '' needs to be a literal." - ] - }, "_LEGACY_ERROR_TEMP_1103" : { "message" : [ "Unsupported component type in arrays." 
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 89c27f72ea093..33072f6c44066 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -1305,6 +1305,12 @@ Cannot call function `` because named argument references are not It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query. +### NON_FOLDABLE_ARGUMENT + +[SQLSTATE: 22024](sql-error-conditions-sqlstates.html#class-22-data-exception) + +The function `` requires the parameter `` to be a foldable expression of the type ``, but the actual argument is a non-foldable. + ### NON_LAST_MATCHED_CLAUSE_OMIT_CONDITION [SQLSTATE: 42613](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 51ddf2b85f8c2..30a6bec1868ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -2934,7 +2934,7 @@ object Extract { } } } else { - throw QueryCompilationErrors.requireLiteralParameter(funcName, "field", "string") + throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "field", StringType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index add59a38b7201..89f354db5a97c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -283,10 +283,10 @@ trait CeilFloorExpressionBuilderBase extends ExpressionBuilder { } else if (numArgs == 2) { val scale = expressions(1) if (!(scale.foldable && scale.dataType == IntegerType)) { - throw QueryCompilationErrors.requireLiteralParameter(funcName, "scale", "int") + throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "scale", IntegerType) } if (scale.eval() == null) { - throw QueryCompilationErrors.requireLiteralParameter(funcName, "scale", "int") + throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "scale", IntegerType) } buildWithTwoParams(expressions(0), scale) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 7875ed8fe20fe..38abcc41cbff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -247,7 +247,7 @@ object ToCharacterBuilder extends ExpressionBuilder { case _: DatetimeType => DateFormatClass(inputExpr, format) case _: BinaryType => if (!(format.dataType == StringType && format.foldable)) { - throw QueryCompilationErrors.requireLiteralParameter(funcName, "format", "string") + throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "format", StringType) } format.eval().asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT).trim match { case "base64" => Base64(inputExpr) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index e579e5cf565b2..a97abf8943406 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1207,14 +1207,16 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "failFastMode" -> FailFastMode.name)) } - def requireLiteralParameter( - funcName: String, argName: String, requiredType: String): Throwable = { + def nonFoldableArgumentError( + funcName: String, + paramName: String, + paramType: DataType): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1100", + errorClass = "NON_FOLDABLE_ARGUMENT", messageParameters = Map( - "argName" -> argName, - "funcName" -> funcName, - "requiredType" -> requiredType)) + "funcName" -> toSQLId(funcName), + "paramName" -> toSQLId(paramName), + "paramType" -> toSQLType(paramType))) } def literalTypeUnsupportedForSourceTypeError(field: String, source: Expression): Throwable = { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ceil-floor-with-scale-param.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ceil-floor-with-scale-param.sql.out index c76b2e5284a42..950584caa8160 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ceil-floor-with-scale-param.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ceil-floor-with-scale-param.sql.out @@ -81,11 +81,12 @@ SELECT CEIL(2.5, null) -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "scale", - "funcName" : "ceil", - "requiredType" : "int" + "funcName" : "`ceil`", + "paramName" : "`scale`", + "paramType" : "\"INT\"" }, "queryContext" : [ { "objectType" : "", @@ -102,11 +103,12 @@ SELECT CEIL(2.5, 'a') -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "scale", - "funcName" : "ceil", - "requiredType" : "int" + "funcName" : "`ceil`", + "paramName" : "`scale`", + "paramType" : "\"INT\"" }, "queryContext" : [ { "objectType" : "", @@ -223,11 +225,12 @@ SELECT FLOOR(2.5, null) -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "scale", - "funcName" : "floor", - "requiredType" : "int" + "funcName" : "`floor`", + "paramName" : "`scale`", + "paramType" : "\"INT\"" }, "queryContext" : [ { "objectType" : "", @@ -244,11 +247,12 @@ SELECT FLOOR(2.5, 'a') -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "scale", - "funcName" : "floor", - "requiredType" : "int" + "funcName" : "`floor`", + "paramName" : "`scale`", + "paramType" : "\"INT\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/extract.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/extract.sql.out index 6085457deaa06..eabe92ab12de4 100644 --- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/extract.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/extract.sql.out @@ -932,11 +932,12 @@ select date_part(c, c) from t -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "field", - "funcName" : "date_part", - "requiredType" : "string" + "funcName" : "`date_part`", + "paramName" : "`field`", + "paramType" : "\"STRING\"" }, "queryContext" : [ { "objectType" : "", @@ -964,11 +965,12 @@ select date_part(i, i) from t -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "field", - "funcName" : "date_part", - "requiredType" : "string" + "funcName" : "`date_part`", + "paramName" : "`field`", + "paramType" : "\"STRING\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out b/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out index d55e665a2a1ea..b15682b0a512b 100644 --- a/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out @@ -94,11 +94,12 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "scale", - "funcName" : "ceil", - "requiredType" : "int" + "funcName" : "`ceil`", + "paramName" : "`scale`", + "paramType" : "\"INT\"" }, "queryContext" : [ { "objectType" : "", @@ -117,11 +118,12 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "scale", - "funcName" : "ceil", - "requiredType" : "int" + "funcName" : "`ceil`", + "paramName" : "`scale`", + "paramType" : "\"INT\"" }, "queryContext" : [ { "objectType" : "", @@ -253,11 +255,12 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "scale", - "funcName" : "floor", - "requiredType" : "int" + "funcName" : "`floor`", + "paramName" : "`scale`", + "paramType" : "\"INT\"" }, "queryContext" : [ { "objectType" : "", @@ -276,11 +279,12 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "scale", - "funcName" : "floor", - "requiredType" : "int" + "funcName" : "`floor`", + "paramName" : "`scale`", + "paramType" : "\"INT\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out index cc6e8bcb36cd7..8416327ef3154 100644 --- a/sql/core/src/test/resources/sql-tests/results/extract.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -714,11 +714,12 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : 
"_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "field", - "funcName" : "date_part", - "requiredType" : "string" + "funcName" : "`date_part`", + "paramName" : "`field`", + "paramType" : "\"STRING\"" }, "queryContext" : [ { "objectType" : "", @@ -745,11 +746,12 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "22024", "messageParameters" : { - "argName" : "field", - "funcName" : "date_part", - "requiredType" : "string" + "funcName" : "`date_part`", + "paramName" : "`field`", + "paramType" : "\"STRING\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 03b9053c71ab9..c61a62f293fa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -875,11 +875,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.select(func(col("input"), col("format"))).collect() }, - errorClass = "_LEGACY_ERROR_TEMP_1100", + errorClass = "NON_FOLDABLE_ARGUMENT", parameters = Map( - "argName" -> "format", - "funcName" -> funcName, - "requiredType" -> "string")) + "funcName" -> s"`$funcName`", + "paramName" -> "`format`", + "paramType" -> "\"STRING\"")) checkError( exception = intercept[AnalysisException] { df2.select(func(col("input"), lit("invalid_format"))).collect() From 63365e7c0f242163e30d7d29690b85e9127d8a11 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Fri, 1 Sep 2023 09:16:28 +0800 Subject: [PATCH 06/35] [SPARK-44994][PYTHON][DOCS] Refine docstring of DataFrame.filter ### What changes were proposed in this pull request? This PR refines the docstring of `DataFrame.filter` by adding more examples. ### Why are the changes needed? To improve PySpark documentation. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? doctest ### Was this patch authored or co-authored using generative AI tooling? No Closes #42708 from allisonwang-db/spark-44994-refine-filter. Authored-by: allisonwang-db Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/dataframe.py | 143 +++++++++++++++++++++++++++----- 1 file changed, 120 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1d48e14b42013..8417d445eea87 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -3361,48 +3361,145 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": Parameters ---------- condition : :class:`Column` or str - a :class:`Column` of :class:`types.BooleanType` + A :class:`Column` of :class:`types.BooleanType` or a string of SQL expressions. Returns ------- :class:`DataFrame` - Filtered DataFrame. + A new DataFrame with rows that satisfy the condition. Examples -------- >>> df = spark.createDataFrame([ - ... (2, "Alice"), (5, "Bob")], schema=["age", "name"]) + ... (2, "Alice", "Math"), (5, "Bob", "Physics"), (7, "Charlie", "Chemistry")], + ... schema=["age", "name", "subject"]) Filter by :class:`Column` instances. 
>>> df.filter(df.age > 3).show() - +---+----+ - |age|name| - +---+----+ - | 5| Bob| - +---+----+ + +---+-------+---------+ + |age| name| subject| + +---+-------+---------+ + | 5| Bob| Physics| + | 7|Charlie|Chemistry| + +---+-------+---------+ >>> df.where(df.age == 2).show() - +---+-----+ - |age| name| - +---+-----+ - | 2|Alice| - +---+-----+ + +---+-----+-------+ + |age| name|subject| + +---+-----+-------+ + | 2|Alice| Math| + +---+-----+-------+ Filter by SQL expression in a string. >>> df.filter("age > 3").show() - +---+----+ - |age|name| - +---+----+ - | 5| Bob| - +---+----+ + +---+-------+---------+ + |age| name| subject| + +---+-------+---------+ + | 5| Bob| Physics| + | 7|Charlie|Chemistry| + +---+-------+---------+ >>> df.where("age = 2").show() - +---+-----+ - |age| name| - +---+-----+ - | 2|Alice| - +---+-----+ + +---+-----+-------+ + |age| name|subject| + +---+-----+-------+ + | 2|Alice| Math| + +---+-----+-------+ + + Filter by multiple conditions. + + >>> df.filter((df.age > 3) & (df.subject == "Physics")).show() + +---+----+-------+ + |age|name|subject| + +---+----+-------+ + | 5| Bob|Physics| + +---+----+-------+ + >>> df.filter((df.age == 2) | (df.subject == "Chemistry")).show() + +---+-------+---------+ + |age| name| subject| + +---+-------+---------+ + | 2| Alice| Math| + | 7|Charlie|Chemistry| + +---+-------+---------+ + + Filter by multiple conditions using SQL expression. + + >>> df.filter("age > 3 AND name = 'Bob'").show() + +---+----+-------+ + |age|name|subject| + +---+----+-------+ + | 5| Bob|Physics| + +---+----+-------+ + + Filter using the :func:`Column.isin` function. + + >>> df.filter(df.name.isin("Alice", "Bob")).show() + +---+-----+-------+ + |age| name|subject| + +---+-----+-------+ + | 2|Alice| Math| + | 5| Bob|Physics| + +---+-----+-------+ + + Filter by a list of values using the :func:`Column.isin` function. + + >>> df.filter(df.subject.isin(["Math", "Physics"])).show() + +---+-----+-------+ + |age| name|subject| + +---+-----+-------+ + | 2|Alice| Math| + | 5| Bob|Physics| + +---+-----+-------+ + + Filter using the `~` operator to exclude certain values. + + >>> df.filter(~df.name.isin(["Alice", "Charlie"])).show() + +---+----+-------+ + |age|name|subject| + +---+----+-------+ + | 5| Bob|Physics| + +---+----+-------+ + + Filter using the :func:`Column.isNotNull` function. + + >>> df.filter(df.name.isNotNull()).show() + +---+-------+---------+ + |age| name| subject| + +---+-------+---------+ + | 2| Alice| Math| + | 5| Bob| Physics| + | 7|Charlie|Chemistry| + +---+-------+---------+ + + Filter using the :func:`Column.like` function. + + >>> df.filter(df.name.like("Al%")).show() + +---+-----+-------+ + |age| name|subject| + +---+-----+-------+ + | 2|Alice| Math| + +---+-----+-------+ + + Filter using the :func:`Column.contains` function. + + >>> df.filter(df.name.contains("i")).show() + +---+-------+---------+ + |age| name| subject| + +---+-------+---------+ + | 2| Alice| Math| + | 7|Charlie|Chemistry| + +---+-------+---------+ + + Filter using the :func:`Column.between` function. 
+ + >>> df.filter(df.age.between(2, 5)).show() + +---+-----+-------+ + |age| name|subject| + +---+-----+-------+ + | 2|Alice| Math| + | 5| Bob|Physics| + +---+-----+-------+ """ if isinstance(condition, str): jdf = self._jdf.filter(condition) From 4bb3aaeb4c3f13a723b6da30fe07c007e417b98c Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 1 Sep 2023 09:18:22 +0800 Subject: [PATCH 07/35] [SPARK-45028][PYTHON][DOCS] Refine docstring of `DataFrame.drop` ### What changes were proposed in this pull request? This pr aims to refine docstring of `DataFrame.drop`. ### Why are the changes needed? To improve PySpark documentation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42748 from panbingkun/SPARK-45028. Authored-by: panbingkun Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/dataframe.py | 45 ++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8417d445eea87..42d85b82e9e21 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -5513,7 +5513,8 @@ def drop(self, *cols: str) -> "DataFrame": ... def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] - """Returns a new :class:`DataFrame` without specified columns. + """ + Returns a new :class:`DataFrame` without specified columns. This is a no-op if the schema doesn't contain the given column name(s). .. versionadded:: 1.4.0 @@ -5524,28 +5525,26 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] Parameters ---------- cols: str or :class:`Column` - a name of the column, or the :class:`Column` to drop + A name of the column, or the :class:`Column` to be dropped. Returns ------- :class:`DataFrame` - DataFrame without given columns. + A new :class:`DataFrame` without the specified columns. Notes ----- - When an input is a column name, it is treated literally without further interpretation. - Otherwise, will try to match the equivalent expression. - So that dropping column by its name `drop(colName)` has different semantic with directly - dropping the column `drop(col(colName))`. + - When an input is a column name, it is treated literally without further interpretation. + Otherwise, it will try to match the equivalent expression. + So dropping a column by its name `drop(colName)` has a different semantic + with directly dropping the column `drop(col(colName))`. Examples -------- - >>> from pyspark.sql import Row - >>> from pyspark.sql.functions import col, lit + Example 1: Drop a column by name. + >>> df = spark.createDataFrame( ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) - >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) - >>> df.drop('age').show() +-----+ | name| @@ -5554,6 +5553,9 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] |Alice| | Bob| +-----+ + + Example 2: Drop a column by :class:`Column` object. + >>> df.drop(df.age).show() +-----+ | name| @@ -5563,9 +5565,10 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] | Bob| +-----+ - Drop the column that joined both DataFrames on. + Example 3: Drop the column that joined both DataFrames on. 
- >>> df.join(df2, df.name == df2.name, 'inner').drop('name').sort('age').show() + >>> df2 = spark.createDataFrame([(80, "Tom"), (85, "Bob")], ["height", "name"]) + >>> df.join(df2, df.name == df2.name).drop('name').sort('age').show() +---+------+ |age|height| +---+------+ @@ -5586,7 +5589,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] | 16| Bob| 85| Bob| +---+-----+------+----+ - Drop two column by the same name. + Example 4: Drop two column by the same name. >>> df3.drop("name").show() +---+------+ @@ -5600,14 +5603,18 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] | 16| 85| +---+------+ - Can not drop col('name') due to ambiguous reference. + Example 5: Can not drop col('name') due to ambiguous reference. - >>> df3.drop(col("name")).show() + >>> from pyspark.sql import functions as sf + >>> df3.drop(sf.col("name")).show() Traceback (most recent call last): ... pyspark.errors.exceptions.captured.AnalysisException: [AMBIGUOUS_REFERENCE] Reference... - >>> df4 = df.withColumn("a.b.c", lit(1)) + Example 6: Can not find a column matching the expression "a.b.c". + + >>> from pyspark.sql import functions as sf + >>> df4 = df.withColumn("a.b.c", sf.lit(1)) >>> df4.show() +---+-----+-----+ |age| name|a.b.c| @@ -5626,9 +5633,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] | 16| Bob| +---+-----+ - Can not find a column matching the expression "a.b.c". - - >>> df4.drop(col("a.b.c")).show() + >>> df4.drop(sf.col("a.b.c")).show() +---+-----+-----+ |age| name|a.b.c| +---+-----+-----+ From c783e5abb58a04e1151ded7553dfe6ebdcc0c15e Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Fri, 1 Sep 2023 09:25:02 +0800 Subject: [PATCH 08/35] [SPARK-44990][SQL][FOLLOWUP] Remove lazy of `nullAsQuotedEmptyString` ### What changes were proposed in this pull request? This is a follow up PR of #42738 . Remove lazy of `nullAsQuotedEmptyString`. ### Why are the changes needed? Remove lazy of nullAsQuotedEmptyString. Because most case it useless. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? add new test. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42744 from Hisoka-X/SPARK-44990_csv_null_benchmark. Authored-by: Jia Fan Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index c8eded3ccd4d5..0ed89f8cba2d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -60,7 +60,7 @@ class UnivocityGenerator( options.locale, legacyFormat = FAST_DATE_FORMAT, isParsing = false) - private lazy val nullAsQuotedEmptyString = + private val nullAsQuotedEmptyString = SQLConf.get.getConf(SQLConf.LEGACY_NULL_VALUE_WRITTEN_AS_QUOTED_EMPTY_STRING_CSV) @scala.annotation.tailrec From 10f31904636da727e439d5f1792c3ab7e8e1d24b Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 1 Sep 2023 10:52:44 +0800 Subject: [PATCH 09/35] [SPARK-45029][CONNECT][TESTS] Ignore `from_protobuf messageClassName/from_protobuf messageClassName options` in `PlanGenerationTestSuite` ### What changes were proposed in this pull request? 
This PR ignores `from_protobuf messageClassName` and `from_protobuf messageClassName options` in `PlanGenerationTestSuite` and removes the related golden files; after this change, `from_protobuf_messageClassName` and `from_protobuf_messageClassName_options` in `ProtoToParsedPlanTestSuite` are ignored as well. ### Why are the changes needed? SPARK-43646 (https://github.com/apache/spark/pull/42236) made both Maven and SBT use the shaded `spark-protobuf` module when testing the connect module, which allows `mvn clean install` and `mvn package test` to pass. But if `mvn clean test` is executed directly, an error `package org.sparkproject.spark_protobuf.protobuf does not exist` occurs. This is because `mvn clean test` tests directly against the classes directory of the `spark-protobuf` module, without running the `package` phase, so protobuf is neither shaded nor relocated. On the other hand, the change of SPARK-43646 broke the usability of importing Spark as a Maven project into IDEA (https://github.com/apache/spark/pull/42236#issuecomment-1700493815). So https://github.com/apache/spark/pull/42746 reverted the change of [SPARK-43646](https://issues.apache.org/jira/browse/SPARK-43646). It is difficult to find a perfect solution for these Maven test issues right now, because in certain scenarios the tests would use the shaded `spark-protobuf` jar (like `mvn package test`), while in other scenarios they use the unshaded classes directory (such as `mvn clean test`). So this PR ignores the relevant tests for now and leaves a TODO (SPARK-45030) to re-enable them once a better solution is found. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #42751 from LuciferYang/SPARK-45029. 
Authored-by: yangjie01 Signed-off-by: yangjie01 --- .../spark/sql/PlanGenerationTestSuite.scala | 8 +++- .../from_protobuf_messageClassName.explain | 2 - ..._protobuf_messageClassName_options.explain | 2 - .../from_protobuf_messageClassName.json | 29 ------------ .../from_protobuf_messageClassName.proto.bin | Bin 125 -> 0 bytes ...rom_protobuf_messageClassName_options.json | 42 ------------------ ...rotobuf_messageClassName_options.proto.bin | Bin 174 -> 0 bytes 7 files changed, 6 insertions(+), 77 deletions(-) delete mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain delete mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain delete mode 100644 connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json delete mode 100644 connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin delete mode 100644 connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.json delete mode 100644 connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.proto.bin diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index ccd68f75bdab1..c457f26921358 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -3235,11 +3235,15 @@ class PlanGenerationTestSuite private val testDescFilePath: String = s"${IntegrationTestUtils.sparkHome}/connector/" + "connect/common/src/test/resources/protobuf-tests/common.desc" - test("from_protobuf messageClassName") { + // TODO(SPARK-45030): Re-enable this test when all Maven test scenarios succeed and there + // are no other negative impacts. For the problem description, please refer to SPARK-45029 + ignore("from_protobuf messageClassName") { binary.select(pbFn.from_protobuf(fn.col("bytes"), classOf[StorageLevel].getName)) } - test("from_protobuf messageClassName options") { + // TODO(SPARK-45030): Re-enable this test when all Maven test scenarios succeed and there + // are no other negative impacts. 
For the problem description, please refer to SPARK-45029 + ignore("from_protobuf messageClassName options") { binary.select( pbFn.from_protobuf( fn.col("bytes"), diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain deleted file mode 100644 index e7a1867fe9072..0000000000000 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [from_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None) AS from_protobuf(bytes)#0] -+- LocalRelation , [id#0L, bytes#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain deleted file mode 100644 index c02d829fcac1d..0000000000000 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [from_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None, (recursive.fields.max.depth,2)) AS from_protobuf(bytes)#0] -+- LocalRelation , [id#0L, bytes#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json deleted file mode 100644 index dc23ac2a117b4..0000000000000 --- a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "common": { - "planId": "1" - }, - "project": { - "input": { - "common": { - "planId": "0" - }, - "localRelation": { - "schema": "struct\u003cid:bigint,bytes:binary\u003e" - } - }, - "expressions": [{ - "unresolvedFunction": { - "functionName": "from_protobuf", - "arguments": [{ - "unresolvedAttribute": { - "unparsedIdentifier": "bytes" - } - }, { - "literal": { - "string": "org.apache.spark.connect.proto.StorageLevel" - } - }] - } - }] - } -} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin deleted file mode 100644 index cc46234b7476cfc4d3623315a0016f4a8d2f1b16..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 125 zcmd;L5@3`n=ThTh5@3i@5Rxk{DJo4avB^xaO3F;n%q!7Jsw_z@26FNeiz@A;e5Jg( zc+-mVbK?t&@=NlQO4Ecmh1j`R!K#GxxpcF%^NZ5;5(^TOGg9@63lfX6^^)`R@=}va U^uRjwf=lv?64O(CQp-|v0OO}8)Bpeg diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.json b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.json deleted file mode 100644 index 36f69646ef83d..0000000000000 --- a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.json +++ /dev/null @@ -1,42 +0,0 @@ -{ - "common": { - "planId": "1" - }, - "project": { - "input": { - "common": { - "planId": "0" - }, - "localRelation": { - "schema": "struct\u003cid:bigint,bytes:binary\u003e" - } - }, - "expressions": [{ - "unresolvedFunction": { - "functionName": "from_protobuf", 
- "arguments": [{ - "unresolvedAttribute": { - "unparsedIdentifier": "bytes" - } - }, { - "literal": { - "string": "org.apache.spark.connect.proto.StorageLevel" - } - }, { - "unresolvedFunction": { - "functionName": "map", - "arguments": [{ - "literal": { - "string": "recursive.fields.max.depth" - } - }, { - "literal": { - "string": "2" - } - }] - } - }] - } - }] - } -} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.proto.bin deleted file mode 100644 index 72a1c6b8207e9bdb15ca23c59df0c962fc6a5eaf..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 174 zcmW;Du?oU45P)H;6!AbP8G{rj7wZX*B0hnen~S9BwN156xTKI3&W989s#O=Y@}<*z{gAD!aYQGVnb<|W(q>evRG_obMLDa3;kTi7&f z%M~i8bWP};;u~t)b)*9h2cCmvRndtabQdTyn6%1?6c&wS(mi|gAS?~t3y-aOVnHs{ JB8Ev5?SAcDHU$6x From b0f9978ec08caa9302c7340951b5c2979315ca13 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 1 Sep 2023 11:12:51 +0800 Subject: [PATCH 10/35] [SPARK-45026][CONNECT] `spark.sql` should support datatypes not compatible with arrow ### What changes were proposed in this pull request? Move the arrow batch creation to the `isCommand` branch ### Why are the changes needed? https://github.com/apache/spark/pull/42736 and https://github.com/apache/spark/pull/42743 introduced the `CalendarIntervalType` in Spark Connect Python Client, however, there is a failure ``` spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)") ... pyspark.errors.exceptions.connect.UnsupportedOperationException: [UNSUPPORTED_DATATYPE] Unsupported data type "INTERVAL". ``` The root causes is that `handleSqlCommand` always create an arrow batch while `ArrowUtils` doesn't accept `CalendarIntervalType` now. this PR mainly focus on enabling `schema` with datatypes not compatible with arrow. In the future, we should make `ArrowUtils` accept `CalendarIntervalType` to make `collect/toPandas` works ### Does this PR introduce _any_ user-facing change? yes after this PR ``` In [1]: spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)") Out[1]: DataFrame[make_interval(100, 11, 1, 1, 12, 30, 1.001001): interval] In [2]: spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)").schema Out[2]: StructType([StructField('make_interval(100, 11, 1, 1, 12, 30, 1.001001)', CalendarIntervalType(), True)]) ``` ### How was this patch tested? enabled ut ### Was this patch authored or co-authored using generative AI tooling? no Closes #42754 from zhengruifeng/connect_sql_types. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../connect/planner/SparkConnectPlanner.scala | 40 +++++++++---------- .../sql/tests/connect/test_parity_types.py | 4 -- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index fbe877b454764..547b6a9fb4039 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2469,30 +2469,30 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong val timeZoneId = session.sessionState.conf.sessionLocalTimeZone - // Convert the data. - val bytes = if (rows.isEmpty) { - ArrowConverters.createEmptyArrowBatch( - schema, - timeZoneId, - errorOnDuplicatedFieldNames = false) - } else { - val batches = ArrowConverters.toBatchWithSchemaIterator( - rowIter = rows.iterator, - schema = schema, - maxRecordsPerBatch = -1, - maxEstimatedBatchSize = maxBatchSize, - timeZoneId = timeZoneId, - errorOnDuplicatedFieldNames = false) - assert(batches.hasNext) - val bytes = batches.next() - assert(!batches.hasNext, s"remaining batches: ${batches.size}") - bytes - } - // To avoid explicit handling of the result on the client, we build the expected input // of the relation on the server. The client has to simply forward the result. val result = SqlCommandResult.newBuilder() if (isCommand) { + // Convert the data. + val bytes = if (rows.isEmpty) { + ArrowConverters.createEmptyArrowBatch( + schema, + timeZoneId, + errorOnDuplicatedFieldNames = false) + } else { + val batches = ArrowConverters.toBatchWithSchemaIterator( + rowIter = rows.iterator, + schema = schema, + maxRecordsPerBatch = -1, + maxEstimatedBatchSize = maxBatchSize, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = false) + assert(batches.hasNext) + val bytes = batches.next() + assert(!batches.hasNext, s"remaining batches: ${batches.size}") + bytes + } + result.setRelation( proto.Relation .newBuilder() diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py index 44171fd61a35b..807c295fae2a1 100644 --- a/python/pyspark/sql/tests/connect/test_parity_types.py +++ b/python/pyspark/sql/tests/connect/test_parity_types.py @@ -86,10 +86,6 @@ def test_rdd_with_udt(self): def test_udt(self): super().test_udt() - @unittest.skip("SPARK-45026: spark.sql should support datatypes not compatible with arrow") - def test_calendar_interval_type(self): - super().test_calendar_interval_type() - if __name__ == "__main__": import unittest From 74447ab5817bcb76f8979ae21bf38a6d193a3e46 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Thu, 31 Aug 2023 23:16:24 -0500 Subject: [PATCH 11/35] [SPARK-44162][CORE] Support G1GC in spark metrics ### What changes were proposed in this pull request? As a part of support JDK21, add support `G1 Concurrent GC` GarbageCollectorMXBean in Spark metrics. Refer https://github.com/openjdk/jdk/pull/11341 and https://bugs.openjdk.org/browse/JDK-8297247 , the `G1 Concurrent GC` not a part of YoungGC or FullGC. So we follow the JDK definition, bring two new metrics: `ConcurrentGCCount` and `ConcurrentGCTime`. 
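For illustration only (not part of this patch), a minimal sketch of how the new collector surfaces through the standard JMX API that `GarbageCollectionMetrics` matches on by name, assuming a JVM started with `-XX:+UseG1GC` on a JDK that includes JDK-8297247:

```
import java.lang.management.ManagementFactory
import scala.jdk.CollectionConverters._

// List every GarbageCollectorMXBean with its cumulative count and time.
// With G1 on a recent JDK one of the bean names is "G1 Concurrent GC",
// which is what feeds ConcurrentGCCount / ConcurrentGCTime.
ManagementFactory.getGarbageCollectorMXBeans.asScala.foreach { mxBean =>
  println(s"${mxBean.getName}: count=${mxBean.getCollectionCount} time=${mxBean.getCollectionTime}ms")
}
```
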
### Why are the changes needed? add new builtin garbage collectors for metrics. ### Does this PR introduce _any_ user-facing change? Yes, will receive new metrics. ### How was this patch tested? Test in local, I will add test if necessary. Closes #41808 from Hisoka-X/SPARK-44162_G1GC_on_JDK21. Authored-by: Jia Fan Signed-off-by: Mridul Muralidharan gmail.com> --- .../spark/metrics/ExecutorMetricType.scala | 11 +++++-- .../status/api/v1/PrometheusResource.scala | 4 +-- .../complete_stage_list_json_expectation.json | 12 ++++++-- ...xcludeOnFailure_for_stage_expectation.json | 12 ++++++-- ...eOnFailure_node_for_stage_expectation.json | 24 +++++++++++---- .../executor_list_json_expectation.json | 4 ++- ...ith_executor_metrics_json_expectation.json | 16 +++++++--- .../executor_memory_usage_expectation.json | 16 +++++++--- ...tor_node_excludeOnFailure_expectation.json | 16 +++++++--- ...ludeOnFailure_unexcluding_expectation.json | 16 +++++++--- .../failed_stage_list_json_expectation.json | 4 ++- ..._details_with_failed_task_expectation.json | 8 +++-- .../one_stage_attempt_json_expectation.json | 8 +++-- .../one_stage_json_expectation.json | 8 +++-- ...e_stage_json_with_details_expectation.json | 8 +++-- ...age_json_with_partitionId_expectation.json | 8 +++-- .../stage_list_json_expectation.json | 16 +++++++--- ...ist_with_accumulable_json_expectation.json | 4 ++- ...ge_list_with_peak_metrics_expectation.json | 12 ++++++-- ...age_with_accumulable_json_expectation.json | 8 +++-- .../stage_with_peak_metrics_expectation.json | 12 ++++++-- ..._with_speculation_summary_expectation.json | 20 +++++++++---- .../stage_with_summaries_expectation.json | 16 +++++++--- .../apache/spark/util/JsonProtocolSuite.scala | 30 ++++++++++++------- 24 files changed, 217 insertions(+), 76 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala b/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala index 648532faa3a1c..1e80eb66dc520 100644 --- a/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala +++ b/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala @@ -110,10 +110,12 @@ case object GarbageCollectionMetrics extends ExecutorMetricType with Logging { "MinorGCTime", "MajorGCCount", "MajorGCTime", - "TotalGCTime" + "TotalGCTime", + "ConcurrentGCCount", + "ConcurrentGCTime" ) - /* We builtin some common GC collectors which categorized as young generation and old */ + /* We builtin some common GC collectors */ private[spark] val YOUNG_GENERATION_BUILTIN_GARBAGE_COLLECTORS = Seq( "Copy", "PS Scavenge", @@ -128,6 +130,8 @@ case object GarbageCollectionMetrics extends ExecutorMetricType with Logging { "G1 Old Generation" ) + private[spark] val BUILTIN_CONCURRENT_GARBAGE_COLLECTOR = "G1 Concurrent GC" + private lazy val youngGenerationGarbageCollector: Seq[String] = { SparkEnv.get.conf.get(config.EVENT_LOG_GC_METRICS_YOUNG_GENERATION_GARBAGE_COLLECTORS) } @@ -147,6 +151,9 @@ case object GarbageCollectionMetrics extends ExecutorMetricType with Logging { } else if (oldGenerationGarbageCollector.contains(mxBean.getName)) { gcMetrics(2) = mxBean.getCollectionCount gcMetrics(3) = mxBean.getCollectionTime + } else if (BUILTIN_CONCURRENT_GARBAGE_COLLECTOR.equals(mxBean.getName)) { + gcMetrics(5) = mxBean.getCollectionCount + gcMetrics(6) = mxBean.getCollectionTime } else if (!nonBuiltInCollectors.contains(mxBean.getName)) { nonBuiltInCollectors = mxBean.getName +: nonBuiltInCollectors // log it when first seen diff --git 
a/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala index 9658e5e627724..ca088dc80550b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala @@ -97,10 +97,10 @@ private[v1] class PrometheusResource extends ApiRequestContext { names.foreach { name => sb.append(s"$prefix${name}_bytes$labels ${m.getMetricValue(name)}\n") } - Seq("MinorGCCount", "MajorGCCount").foreach { name => + Seq("MinorGCCount", "MajorGCCount", "ConcurrentGCCount").foreach { name => sb.append(s"$prefix${name}_total$labels ${m.getMetricValue(name)}\n") } - Seq("MinorGCTime", "MajorGCTime").foreach { name => + Seq("MinorGCTime", "MajorGCTime", "ConcurrentGCTime").foreach { name => sb.append(s"$prefix${name}_seconds_total$labels ${m.getMetricValue(name) * 0.001}\n") } } diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 850c3777ec4d8..ac0f2ce26051d 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -76,7 +76,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } }, { "status" : "COMPLETE", @@ -156,7 +158,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } }, { "status" : "COMPLETE", @@ -236,6 +240,8 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/excludeOnFailure_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/excludeOnFailure_for_stage_expectation.json index ed4ed9ad87185..d614bb000e4ce 100644 --- a/core/src/test/resources/HistoryServerExpectations/excludeOnFailure_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/excludeOnFailure_for_stage_expectation.json @@ -887,7 +887,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : true }, @@ -928,7 +930,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -956,6 +960,8 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } } diff --git a/core/src/test/resources/HistoryServerExpectations/excludeOnFailure_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/excludeOnFailure_node_for_stage_expectation.json index f96a59fae5378..475dee00a2654 100644 --- a/core/src/test/resources/HistoryServerExpectations/excludeOnFailure_node_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/excludeOnFailure_node_for_stage_expectation.json @@ -1021,7 +1021,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, 
- "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : true }, @@ -1062,7 +1064,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : true }, @@ -1103,7 +1107,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false }, @@ -1144,7 +1150,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false }, @@ -1185,7 +1193,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : true } @@ -1213,6 +1223,8 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index ec3fc280b0a5e..a860682ca2e24 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -42,7 +42,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json index 9b7498d9e9145..2833cdcfde5dd 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json @@ -48,7 +48,9 @@ "MinorGCTime" : 55, "MajorGCCount" : 3, "MajorGCTime" : 144, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -108,7 +110,9 @@ "MinorGCTime" : 145, "MajorGCCount" : 2, "MajorGCTime" : 63, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { "NM_HTTP_ADDRESS" : "test-3.vpc.company.com:8042", @@ -178,7 +182,9 @@ "MinorGCTime" : 106, "MajorGCCount" : 2, "MajorGCTime" : 75, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { "NM_HTTP_ADDRESS" : "test-4.vpc.company.com:8042", @@ -248,7 +254,9 @@ "MinorGCTime" : 140, "MajorGCCount" : 2, "MajorGCTime" : 60, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { "NM_HTTP_ADDRESS" : "test-2.vpc.company.com:8042", diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index fbb7b6631f02a..8a96858a2014a 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -85,7 +85,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -145,7 +147,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -205,7 +209,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -265,7 +271,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_excludeOnFailure_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_excludeOnFailure_expectation.json index fbb7b6631f02a..8a96858a2014a 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_excludeOnFailure_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_excludeOnFailure_expectation.json @@ -85,7 +85,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -145,7 +147,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -205,7 +209,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -265,7 +271,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_excludeOnFailure_unexcluding_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_excludeOnFailure_unexcluding_expectation.json index b72ed0a625420..0e5e73f36fabd 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_excludeOnFailure_unexcluding_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_excludeOnFailure_unexcluding_expectation.json @@ -73,7 +73,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -127,7 +129,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -181,7 +185,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, @@ -235,7 +241,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + 
"ConcurrentGCTime" : 0 }, "attributes" : { }, "resources" : { }, diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index fee7377f18134..dc1bcd6a39625 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -75,7 +75,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_details_with_failed_task_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_details_with_failed_task_expectation.json index 9e390a995c36c..e24ac4f82b8a8 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_details_with_failed_task_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_details_with_failed_task_expectation.json @@ -92,7 +92,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -120,6 +122,8 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 887d2678e6160..659e3c41d9289 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -595,7 +595,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -625,6 +627,8 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 3bb59aaf5b507..f84cf26fcf1d6 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -595,7 +595,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -625,6 +627,8 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_with_details_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_with_details_expectation.json index b688b72b04d50..564f3eadd1cc2 100644 --- 
a/core/src/test/resources/HistoryServerExpectations/one_stage_json_with_details_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_with_details_expectation.json @@ -597,7 +597,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -625,6 +627,8 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_with_partitionId_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_with_partitionId_expectation.json index 83ffb7da8e77f..2bf7f34803775 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_with_partitionId_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_with_partitionId_expectation.json @@ -721,7 +721,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -749,7 +751,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index e3cd980943450..8df41bfcc8d78 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -74,7 +74,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 @@ -155,7 +157,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 @@ -235,7 +239,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 @@ -315,7 +321,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index e4caffcf10787..730df3fbd5341 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -78,7 +78,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 diff --git 
a/core/src/test/resources/HistoryServerExpectations/stage_list_with_peak_metrics_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_peak_metrics_expectation.json index d3459be777d48..16d92244ee0d0 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_peak_metrics_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_peak_metrics_expectation.json @@ -74,7 +74,9 @@ "MinorGCTime" : 115, "MajorGCCount" : 4, "MajorGCTime" : 339, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 @@ -155,7 +157,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 @@ -236,7 +240,9 @@ "MinorGCTime" : 33, "MajorGCCount" : 3, "MajorGCTime" : 110, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 3880818a7b5df..e38741b7bf6e7 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -639,7 +639,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -667,7 +669,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_peak_metrics_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_peak_metrics_expectation.json index d3eb7d55e0e1d..630b0512e8f98 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_peak_metrics_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_peak_metrics_expectation.json @@ -1147,7 +1147,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false }, @@ -1188,7 +1190,9 @@ "MinorGCTime" : 115, "MajorGCCount" : 4, "MajorGCTime" : 339, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -1216,7 +1220,9 @@ "MinorGCTime" : 115, "MajorGCCount" : 4, "MajorGCTime" : 339, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isShufflePushEnabled" : false, "shuffleMergersCount" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_speculation_summary_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_speculation_summary_expectation.json index 3ad18f816fbe5..23770480ad62a 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_speculation_summary_expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/stage_with_speculation_summary_expectation.json @@ -424,7 +424,9 @@ "MinorGCTime" : 280, "MajorGCCount" : 2, "MajorGCTime" : 1116, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false }, @@ -465,7 +467,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false }, @@ -506,7 +510,9 @@ "MinorGCTime" : 587, "MajorGCCount" : 2, "MajorGCTime" : 906, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false }, @@ -547,7 +553,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -584,6 +592,8 @@ "MinorGCTime" : 587, "MajorGCCount" : 2, "MajorGCTime" : 1116, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_summaries_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_summaries_expectation.json index c89b82caf3818..c8458a409589e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_summaries_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_summaries_expectation.json @@ -1147,7 +1147,9 @@ "MinorGCTime" : 0, "MajorGCCount" : 0, "MajorGCTime" : 0, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false }, @@ -1188,7 +1190,9 @@ "MinorGCTime" : 115, "MajorGCCount" : 4, "MajorGCTime" : 339, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "isExcludedForStage" : false } @@ -1216,7 +1220,9 @@ "MinorGCTime" : 115, "MajorGCCount" : 4, "MajorGCTime" : 339, - "TotalGCTime" : 0 + "TotalGCTime" : 0, + "ConcurrentGCCount" : 0, + "ConcurrentGCTime" : 0 }, "taskMetricsDistributions" : { "quantiles" : [ 0.0, 0.25, 0.5, 0.75, 1.0 ], @@ -1306,7 +1312,9 @@ "MinorGCTime" : [ 0.0, 0.0, 115.0, 115.0, 115.0 ], "MajorGCCount" : [ 0.0, 0.0, 4.0, 4.0, 4.0 ], "MajorGCTime" : [ 0.0, 0.0, 339.0, 339.0, 339.0 ], - "TotalGCTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] + "TotalGCTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "ConcurrentGCCount" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "ConcurrentGCTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] } }, "isShufflePushEnabled" : false, diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 8105df64705a4..e8d41c4d46e21 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -58,21 +58,21 @@ class JsonProtocolSuite extends SparkFunSuite { makeTaskInfo(123L, 234, 67, 234, 345L, false), new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L, 256912L, 123456L, 123456L, 61728L, 30364L, 15182L, - 0, 0, 0, 0, 80001L)), + 0, 0, 0, 0, 80001L, 3, 3)), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, 0, hasHadoopInput = false, hasOutput = false)) val taskEndWithHadoopInput = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success, makeTaskInfo(123L, 234, 67, 234, 345L, false), new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 
432L, 321L, 654L, 765L, 256912L, 123456L, 123456L, 61728L, 30364L, 15182L, - 0, 0, 0, 0, 80001L)), + 0, 0, 0, 0, 80001L, 3, 3)), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, 0, hasHadoopInput = true, hasOutput = false)) val taskEndWithOutput = SparkListenerTaskEnd(1, 0, "ResultTask", Success, makeTaskInfo(123L, 234, 67, 234, 345L, false), new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L, 256912L, 123456L, 123456L, 61728L, 30364L, 15182L, - 0, 0, 0, 0, 80001L)), + 0, 0, 0, 0, 80001L, 3, 3)), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, 0, hasHadoopInput = true, hasOutput = true)) val jobStart = { @@ -136,7 +136,7 @@ class JsonProtocolSuite extends SparkFunSuite { val executorUpdates = new ExecutorMetrics( Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L, 256912L, 123456L, 123456L, 61728L, - 30364L, 15182L, 10L, 90L, 2L, 20L, 80001L)) + 30364L, 15182L, 10L, 90L, 2L, 20L, 80001L, 3, 3)) SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates)), Map((0, 0) -> executorUpdates)) } @@ -147,7 +147,7 @@ class JsonProtocolSuite extends SparkFunSuite { SparkListenerStageExecutorMetrics("1", 2, 3, new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L, 256912L, 123456L, 123456L, 61728L, - 30364L, 15182L, 10L, 90L, 2L, 20L, 80001L))) + 30364L, 15182L, 10L, 90L, 2L, 20L, 80001L, 3, 3))) val rprofBuilder = new ResourceProfileBuilder() val taskReq = new TaskResourceRequests() .cpus(1) @@ -1754,7 +1754,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | "MinorGCTime" : 0, | "MajorGCCount" : 0, | "MajorGCTime" : 0, - | "TotalGCTime" : 80001 + | "TotalGCTime": 80001, + | "ConcurrentGCCount" : 3, + | "ConcurrentGCTime" : 3 | }, | "Task Metrics": { | "Executor Deserialize Time": 300, @@ -1893,7 +1895,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | "MinorGCTime" : 0, | "MajorGCCount" : 0, | "MajorGCTime" : 0, - | "TotalGCTime" : 80001 + | "TotalGCTime": 80001, + | "ConcurrentGCCount" : 3, + | "ConcurrentGCTime" : 3 | }, | "Task Metrics": { | "Executor Deserialize Time": 300, @@ -2032,7 +2036,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | "MinorGCTime" : 0, | "MajorGCCount" : 0, | "MajorGCTime" : 0, - | "TotalGCTime" : 80001 + | "TotalGCTime": 80001, + | "ConcurrentGCCount" : 3, + | "ConcurrentGCTime" : 3 | }, | "Task Metrics": { | "Executor Deserialize Time": 300, @@ -2933,7 +2939,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | "MinorGCTime": 90, | "MajorGCCount": 2, | "MajorGCTime": 20, - | "TotalGCTime" : 80001 + | "TotalGCTime": 80001, + | "ConcurrentGCCount" : 3, + | "ConcurrentGCTime" : 3 | } | } | ] @@ -2968,7 +2976,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | "MinorGCTime": 90, | "MajorGCCount": 2, | "MajorGCTime": 20, - | "TotalGCTime" : 80001 + | "TotalGCTime": 80001, + | "ConcurrentGCCount" : 3, + | "ConcurrentGCTime" : 3 | } |} """.stripMargin From fd7acd32895ed79094ff75aef8ff133966627ee4 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 31 Aug 2023 23:19:41 -0500 Subject: [PATCH 12/35] [SPARK-44238][CORE][SQL] Introduce a new `readFrom` method with byte array input for `BloomFilter` ### What changes were proposed in this pull request? This pr introduce a new `readFrom` method with byte array input for `BloomFilter` ### Why are the changes needed? De-duplicate code ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? 
Pass GitHub Actions Closes #41781 from LuciferYang/bloomfilter-readFrom. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Mridul Muralidharan gmail.com> --- .../java/org/apache/spark/util/sketch/BloomFilter.java | 7 +++++++ .../org/apache/spark/util/sketch/BloomFilterImpl.java | 6 ++++++ .../apache/spark/util/sketch/BloomFilterSuite.scala | 6 ++---- .../catalyst/expressions/BloomFilterMightContain.scala | 10 +--------- .../expressions/aggregate/BloomFilterAggregate.scala | 8 +------- 5 files changed, 17 insertions(+), 20 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index f3c2b05e7af9d..172b394689ca9 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -178,6 +178,13 @@ public static BloomFilter readFrom(InputStream in) throws IOException { return BloomFilterImpl.readFrom(in); } + /** + * Reads in a {@link BloomFilter} from a byte array. + */ + public static BloomFilter readFrom(byte[] bytes) throws IOException { + return BloomFilterImpl.readFrom(bytes); + } + /** * Computes the optimal k (number of hashes per item inserted in Bloom filter), given the * expected insertions and total number of bits in the Bloom filter. diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index ccf1833af9945..3fba5e3325223 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -266,6 +266,12 @@ public static BloomFilterImpl readFrom(InputStream in) throws IOException { return filter; } + public static BloomFilterImpl readFrom(byte[] bytes) throws IOException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes)) { + return readFrom(bis); + } + } + private void writeObject(ObjectOutputStream out) throws IOException { writeTo(out); } diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala index cfdc9954772c5..4d0ba66637b46 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.sketch -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.ByteArrayOutputStream import scala.reflect.ClassTag import scala.util.Random @@ -34,9 +34,7 @@ class BloomFilterSuite extends AnyFunSuite { // scalastyle:ignore funsuite filter.writeTo(out) out.close() - val in = new ByteArrayInputStream(out.toByteArray) - val deserialized = BloomFilter.readFrom(in) - in.close() + val deserialized = BloomFilter.readFrom(out.toByteArray) assert(filter == deserialized) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala index b2273b6a6d13f..784bea899c4c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.ByteArrayInputStream - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -119,11 +117,5 @@ case class BloomFilterMightContain( } } - final def deserialize(bytes: Array[Byte]): BloomFilter = { - val in = new ByteArrayInputStream(bytes) - val bloomFilter = BloomFilter.readFrom(in) - in.close() - bloomFilter - } - + final def deserialize(bytes: Array[Byte]): BloomFilter = BloomFilter.readFrom(bytes) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 7cba462ce2c3e..424e191a0c969 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import org.apache.spark.sql.catalyst.InternalRow @@ -227,12 +226,7 @@ object BloomFilterAggregate { out.toByteArray } - final def deserialize(bytes: Array[Byte]): BloomFilter = { - val in = new ByteArrayInputStream(bytes) - val bloomFilter = BloomFilter.readFrom(in) - in.close() - bloomFilter - } + final def deserialize(bytes: Array[Byte]): BloomFilter = BloomFilter.readFrom(bytes) } private trait BloomFilterUpdater { From 9d28bef2f70b06cbb2f50a6814f8433fa344052e Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Fri, 1 Sep 2023 20:23:53 +0800 Subject: [PATCH 13/35] [SPARK-44743][SQL] Add `try_reflect` function ### What changes were proposed in this pull request? This PR add `try_reflect` function which binds to new expression `TryReflect` that is a runtime replaceable expression of the `CallMethodViaReflection`. ### Why are the changes needed? Add new `try_reflect` so invoke method failed will not cause job fail. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? add new test. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42661 from Hisoka-X/SPARK-44743_hive_reflect. 
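As an illustrative sketch only (assuming an active `SparkSession` named `spark`; this snippet is not part of the patch), the behavioural difference from `reflect` looks roughly like this:

```
// reflect() fails the query when the reflective call throws, e.g.
// URLDecoder.decode("%") raising IllegalArgumentException.
// try_reflect() returns NULL for the same input instead of failing the job:
val row = spark.sql("SELECT try_reflect('java.net.URLDecoder', 'decode', '%')").head()
assert(row.isNullAt(0))

// Equivalently via the new Scala function added in this patch:
import org.apache.spark.sql.functions.{lit, try_reflect}
spark.range(1).select(try_reflect(lit("java.net.URLDecoder"), lit("decode"), lit("%"))).head()
```
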
Authored-by: Jia Fan Signed-off-by: Wenchen Fan --- .../CheckConnectJvmClientCompatibility.scala | 2 + .../function_java_method.explain | 2 +- .../explain-results/function_reflect.explain | 2 +- python/pyspark/sql/tests/test_functions.py | 2 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/CallMethodViaReflection.scala | 18 +- .../sql/catalyst/expressions/TryEval.scala | 32 ++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../sql-functions/sql-expression-schema.md | 1 + .../analyzer-results/try_reflect.sql.out | 146 +++++++++++++++ .../sql-tests/inputs/try_reflect.sql | 15 ++ .../sql-tests/results/try_reflect.sql.out | 170 ++++++++++++++++++ .../sql/expressions/ExpressionInfoSuite.scala | 1 + 13 files changed, 397 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/analyzer-results/try_reflect.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/try_reflect.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/try_reflect.sql.out diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 1e536cd37fec1..bf512ed71fd3f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -208,6 +208,8 @@ object CheckConnectJvmClientCompatibility { // functions ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.functions.try_reflect"), // KeyValueGroupedDataset ProblemFilters.exclude[Problem]( diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_java_method.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_java_method.explain index 0d467be225f98..d2d5730eedf9e 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_java_method.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_java_method.explain @@ -1,2 +1,2 @@ -Project [java_method(java.util.UUID, fromString, g#0) AS java_method(java.util.UUID, fromString, g)#0] +Project [java_method(java.util.UUID, fromString, g#0, true) AS java_method(java.util.UUID, fromString, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_reflect.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_reflect.explain index f52d3e1b0ff42..df790f0878062 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_reflect.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_reflect.explain @@ -1,2 +1,2 @@ -Project [reflect(java.util.UUID, fromString, g#0) AS reflect(java.util.UUID, fromString, g)#0] +Project [reflect(java.util.UUID, fromString, g#0, true) AS reflect(java.util.UUID, fromString, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py index 5a8e36d287c08..218e8eb060ba4 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -87,6 +87,8 @@ def test_function_parity(self): # https://issues.apache.org/jira/browse/SPARK-44788 "from_xml", "schema_of_xml", + # TODO: reflect function will soon be added and removed from this list + "try_reflect", } self.assertEqual( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index eade7afb7cb52..af9c095deb96d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -454,6 +454,7 @@ object FunctionRegistry { expression[TryToBinary]("try_to_binary"), expressionBuilder("try_to_timestamp", TryToTimestampExpressionBuilder, setAlias = true), expression[TryAesDecrypt]("try_aes_decrypt"), + expression[TryReflect]("try_reflect"), // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 52b057a327623..4f5ed0e13659a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.reflect.{Method, Modifier} +import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} @@ -55,11 +57,16 @@ import org.apache.spark.util.Utils """, since = "2.0.0", group = "misc_funcs") -case class CallMethodViaReflection(children: Seq[Expression]) +case class CallMethodViaReflection( + children: Seq[Expression], + failOnError: Boolean = true) extends Nondeterministic with CodegenFallback with QueryErrorsBase { + def this(children: Seq[Expression]) = + this(children, true) + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("reflect") override def checkInputDataTypes(): TypeCheckResult = { @@ -133,8 +140,13 @@ case class CallMethodViaReflection(children: Seq[Expression]) } i += 1 } - val ret = method.invoke(null, buffer : _*) - UTF8String.fromString(String.valueOf(ret)) + try { + val ret = method.invoke(null, buffer : _*) + UTF8String.fromString(String.valueOf(ret)) + } catch { + case NonFatal(_) if !failOnError => + null + } } @transient private lazy val argExprs: Array[Expression] = children.drop(2).toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala index a23f4f6194366..4eacd3442ed5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala @@ -236,3 +236,35 @@ case class TryToBinary( override protected def withNewChildInternal(newChild: Expression): Expression = 
this.copy(replacement = newChild) } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(class, method[, arg1[, arg2 ..]]) - This is a special version of `reflect` that" + + " performs the same operation, but returns a NULL value instead of raising an error if the invoke method thrown exception.", + examples = """ + Examples: + > SELECT _FUNC_('java.util.UUID', 'randomUUID'); + c33fb387-8500-4bfa-81d2-6e0e3e930df2 + > SELECT _FUNC_('java.util.UUID', 'fromString', 'a5cf6c42-0c85-418f-af6c-3e4e5b1328f2'); + a5cf6c42-0c85-418f-af6c-3e4e5b1328f2 + > SELECT _FUNC_('java.net.URLDecoder', 'decode', '%'); + NULL + """, + since = "4.0.0", + group = "misc_funcs") +// scalastyle:on line.size.limit +case class TryReflect(params: Seq[Expression], replacement: Expression) extends RuntimeReplaceable + with InheritAnalysisRules { + + def this(params: Seq[Expression]) = this(params, + CallMethodViaReflection(params, failOnError = false)) + + override def prettyName: String = "try_reflect" + + override def parameters: Seq[Expression] = params + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(replacement = newChild) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d72191ce7f38e..9548f424ad407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3690,6 +3690,16 @@ object functions { CallMethodViaReflection(cols.map(_.expr)) } + /** + * Calls a method with reflection. + * + * @group misc_funcs + * @since 4.0.0 + */ + def try_reflect(cols: Column*): Column = withExpr { + new TryReflect(cols.map(_.expr)) + } + /** * Returns the Spark version. The string contains 2 fields, the first being a release version * and the second being a git revision. 
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 14b48db515b16..f518a67e1faed 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -341,6 +341,7 @@ | org.apache.spark.sql.catalyst.expressions.TryDivide | try_divide | SELECT try_divide(3, 2) | struct | | org.apache.spark.sql.catalyst.expressions.TryElementAt | try_element_at | SELECT try_element_at(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.TryMultiply | try_multiply | SELECT try_multiply(2, 3) | struct | +| org.apache.spark.sql.catalyst.expressions.TryReflect | try_reflect | SELECT try_reflect('java.util.UUID', 'randomUUID') | struct | | org.apache.spark.sql.catalyst.expressions.TrySubtract | try_subtract | SELECT try_subtract(2, 1) | struct | | org.apache.spark.sql.catalyst.expressions.TryToBinary | try_to_binary | SELECT try_to_binary('abc', 'utf-8') | struct | | org.apache.spark.sql.catalyst.expressions.TryToNumber | try_to_number | SELECT try_to_number('454', '999') | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/try_reflect.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/try_reflect.sql.out new file mode 100644 index 0000000000000..0b816cecf1a05 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/try_reflect.sql.out @@ -0,0 +1,146 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT try_reflect("java.util.UUID", "fromString", "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2") +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT try_reflect("java.lang.String", "valueOf", 1) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT try_reflect("java.lang.Math", "max", 2, 3) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT try_reflect("java.lang.Math", "min", 2, 3) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT try_reflect("java.lang.Integer", "valueOf", "10", 16) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT try_reflect("java.util.UUID", "fromString", "b") +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT try_reflect("java.net.URLDecoder", "decode", "%") +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT try_reflect("java.wrongclass.Math", "max", 2, 3) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_CLASS_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "className" : "java.wrongclass.Math", + "sqlExpr" : "\"reflect(java.wrongclass.Math, max, 2, 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 55, + "fragment" : "try_reflect(\"java.wrongclass.Math\", \"max\", 2, 3)" + } ] +} + + +-- !query +SELECT try_reflect("java.lang.Math", "wrongmethod", 2, 3) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_STATIC_METHOD", + "sqlState" : "42K09", + "messageParameters" : { + "className" : "java.lang.Math", + "methodName" : "wrongmethod", + "sqlExpr" : "\"reflect(java.lang.Math, 
wrongmethod, 2, 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 57, + "fragment" : "try_reflect(\"java.lang.Math\", \"wrongmethod\", 2, 3)" + } ] +} + + +-- !query +SELECT try_reflect("java.lang.Math") +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "1", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "> 1", + "functionName" : "`reflect`" + } +} + + +-- !query +SELECT try_reflect("java.lang.Math", "round", 2.5) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"2.5\"", + "inputType" : "\"DECIMAL(2,1)\"", + "paramIndex" : "3", + "requiredType" : "(\"BOOLEAN\" or \"TINYINT\" or \"SMALLINT\" or \"INT\" or \"BIGINT\" or \"FLOAT\" or \"DOUBLE\" or \"STRING\")", + "sqlExpr" : "\"reflect(java.lang.Math, round, 2.5)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 50, + "fragment" : "try_reflect(\"java.lang.Math\", \"round\", 2.5)" + } ] +} + + +-- !query +SELECT try_reflect("java.lang.Object", "toString") +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_STATIC_METHOD", + "sqlState" : "42K09", + "messageParameters" : { + "className" : "java.lang.Object", + "methodName" : "toString", + "sqlExpr" : "\"reflect(java.lang.Object, toString)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 50, + "fragment" : "try_reflect(\"java.lang.Object\", \"toString\")" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_reflect.sql b/sql/core/src/test/resources/sql-tests/inputs/try_reflect.sql new file mode 100644 index 0000000000000..dd2bce7ef1f8c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/try_reflect.sql @@ -0,0 +1,15 @@ +-- positive +SELECT try_reflect("java.util.UUID", "fromString", "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2"); +SELECT try_reflect("java.lang.String", "valueOf", 1); +SELECT try_reflect("java.lang.Math", "max", 2, 3); +SELECT try_reflect("java.lang.Math", "min", 2, 3); +SELECT try_reflect("java.lang.Integer", "valueOf", "10", 16); + +-- negative +SELECT try_reflect("java.util.UUID", "fromString", "b"); +SELECT try_reflect("java.net.URLDecoder", "decode", "%"); +SELECT try_reflect("java.wrongclass.Math", "max", 2, 3); +SELECT try_reflect("java.lang.Math", "wrongmethod", 2, 3); +SELECT try_reflect("java.lang.Math"); +SELECT try_reflect("java.lang.Math", "round", 2.5); +SELECT try_reflect("java.lang.Object", "toString"); diff --git a/sql/core/src/test/resources/sql-tests/results/try_reflect.sql.out b/sql/core/src/test/resources/sql-tests/results/try_reflect.sql.out new file mode 100644 index 0000000000000..13da0edca9898 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/try_reflect.sql.out @@ -0,0 +1,170 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT try_reflect("java.util.UUID", "fromString", "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2") +-- !query schema +struct +-- !query output +a5cf6c42-0c85-418f-af6c-3e4e5b1328f2 + + +-- !query +SELECT try_reflect("java.lang.String", "valueOf", 1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT 
try_reflect("java.lang.Math", "max", 2, 3) +-- !query schema +struct +-- !query output +3 + + +-- !query +SELECT try_reflect("java.lang.Math", "min", 2, 3) +-- !query schema +struct +-- !query output +2 + + +-- !query +SELECT try_reflect("java.lang.Integer", "valueOf", "10", 16) +-- !query schema +struct +-- !query output +16 + + +-- !query +SELECT try_reflect("java.util.UUID", "fromString", "b") +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_reflect("java.net.URLDecoder", "decode", "%") +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_reflect("java.wrongclass.Math", "max", 2, 3) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_CLASS_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "className" : "java.wrongclass.Math", + "sqlExpr" : "\"reflect(java.wrongclass.Math, max, 2, 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 55, + "fragment" : "try_reflect(\"java.wrongclass.Math\", \"max\", 2, 3)" + } ] +} + + +-- !query +SELECT try_reflect("java.lang.Math", "wrongmethod", 2, 3) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_STATIC_METHOD", + "sqlState" : "42K09", + "messageParameters" : { + "className" : "java.lang.Math", + "methodName" : "wrongmethod", + "sqlExpr" : "\"reflect(java.lang.Math, wrongmethod, 2, 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 57, + "fragment" : "try_reflect(\"java.lang.Math\", \"wrongmethod\", 2, 3)" + } ] +} + + +-- !query +SELECT try_reflect("java.lang.Math") +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "1", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "> 1", + "functionName" : "`reflect`" + } +} + + +-- !query +SELECT try_reflect("java.lang.Math", "round", 2.5) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"2.5\"", + "inputType" : "\"DECIMAL(2,1)\"", + "paramIndex" : "3", + "requiredType" : "(\"BOOLEAN\" or \"TINYINT\" or \"SMALLINT\" or \"INT\" or \"BIGINT\" or \"FLOAT\" or \"DOUBLE\" or \"STRING\")", + "sqlExpr" : "\"reflect(java.lang.Math, round, 2.5)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 50, + "fragment" : "try_reflect(\"java.lang.Math\", \"round\", 2.5)" + } ] +} + + +-- !query +SELECT try_reflect("java.lang.Object", "toString") +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_STATIC_METHOD", + "sqlState" : "42K09", + "messageParameters" : { + "className" : "java.lang.Object", + "methodName" : "toString", + "sqlExpr" : "\"reflect(java.lang.Object, toString)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 50, + "fragment" : "try_reflect(\"java.lang.Object\", \"toString\")" + } ] +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 4dd93983e87e3..262412f8cdb78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -190,6 +190,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { "org.apache.spark.sql.catalyst.expressions.InputFileBlockLength", // The example calls methods that return unstable results. "org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection", + "org.apache.spark.sql.catalyst.expressions.TryReflect", "org.apache.spark.sql.catalyst.expressions.SparkVersion", // Throws an error "org.apache.spark.sql.catalyst.expressions.RaiseError", From 00f66994c802faf9ccc0d40ed4f6ff32992ba00f Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Fri, 1 Sep 2023 20:27:17 +0800 Subject: [PATCH 14/35] [SPARK-44577][SQL] Fix INSERT BY NAME returns nonsensical error message ### What changes were proposed in this pull request? Fix INSERT BY NAME returns nonsensical error message on v1 datasource. eg: ```scala CREATE TABLE bug(c1 INT); INSERT INTO bug BY NAME SELECT 1 AS c2; ==> Multi-part identifier cannot be empty. ``` After PR: ```scala [INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA] Cannot write incompatible data for the table `spark_catalog`.`default`.`bug`: Cannot find data for the output column `c1`. ``` Also fixed the same issue when throwing other INCOMPATIBLE_DATA_FOR_TABLE type errors ### Why are the changes needed? Fix the error msg nonsensical. ### Does this PR introduce _any_ user-facing change? Yes, the error msg in v1 insert by name will be changed. ### How was this patch tested? add new test. Closes #42220 from Hisoka-X/SPARK-44577_insert_by_name_bug_fix. Authored-by: Jia Fan Signed-off-by: Wenchen Fan --- .../main/resources/error/error-classes.json | 5 +++ ...incompatible-data-for-table-error-class.md | 4 +++ .../analysis/ResolveInsertionBase.scala | 4 +-- .../analysis/TableOutputResolver.scala | 36 +++++++++++++------ .../sql/errors/QueryCompilationErrors.scala | 19 +++++++--- .../apache/spark/sql/SQLInsertTestSuite.scala | 5 +-- 6 files changed, 54 insertions(+), 19 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index af78dd2f9f801..87b9da7638b2a 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -1035,6 +1035,11 @@ "Cannot safely cast to ." ] }, + "EXTRA_COLUMNS" : { + "message" : [ + "Cannot write extra columns ." + ] + }, "EXTRA_STRUCT_FIELDS" : { "message" : [ "Cannot write extra fields to the struct ." diff --git a/docs/sql-error-conditions-incompatible-data-for-table-error-class.md b/docs/sql-error-conditions-incompatible-data-for-table-error-class.md index f70b69ba6c5bd..0dd28e9d55c50 100644 --- a/docs/sql-error-conditions-incompatible-data-for-table-error-class.md +++ b/docs/sql-error-conditions-incompatible-data-for-table-error-class.md @@ -37,6 +37,10 @@ Cannot find data for the output column ``. Cannot safely cast `` `` to ``. +## EXTRA_COLUMNS + +Cannot write extra columns ``. + ## EXTRA_STRUCT_FIELDS Cannot write extra fields `` to the struct ``. 
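As a reader aid, here is a hedged sketch of the scenario the new `EXTRA_COLUMNS` sub-class covers. The statement only approximates the `BY NAME` case updated in `SQLInsertTestSuite` further below, where the query supplies an `x1` column the table lacks; the exact sub-class depends on the write path (the suite expects `EXTRA_COLUMNS` on v1 and `CANNOT_FIND_DATA` on v2), and the table/column definitions here are assumptions rather than queries taken from this patch:

```sql
-- Illustrative only; `t1` and `x1` mirror the updated test expectation, the rest is assumed.
CREATE TABLE t1 (c1 INT, c2 INT, c3 INT) USING parquet;
INSERT INTO t1 BY NAME SELECT 3 AS c3, 2 AS c2, 1 AS x1;
-- v1 path (expected): [INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS] Cannot write incompatible data
--   for the table `spark_catalog`.`default`.`t1`: Cannot write extra columns `x1`.
-- v2 path (expected): [INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA] ... Cannot find data for the
--   output column `c1`.
```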
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala index 8b120095bc600..ad89005a093e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala @@ -36,10 +36,10 @@ abstract class ResolveInsertionBase extends Rule[LogicalPlan] { if (i.userSpecifiedCols.size != i.query.output.size) { if (i.userSpecifiedCols.size > i.query.output.size) { throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( - tblName, i.userSpecifiedCols, i.query) + tblName, i.userSpecifiedCols, i.query.output) } else { throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError( - tblName, i.userSpecifiedCols, i.query) + tblName, i.userSpecifiedCols, i.query.output) } } val projectByName = i.userSpecifiedCols.zip(i.query.output) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 21575f7b96bed..fc0e727bea529 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -87,7 +87,7 @@ object TableOutputResolver { if (actualExpectedCols.size < query.output.size) { throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError( - tableName, actualExpectedCols.map(_.name), query) + tableName, actualExpectedCols.map(_.name), query.output) } val errors = new mutable.ArrayBuffer[String]() @@ -105,7 +105,7 @@ object TableOutputResolver { } else { if (actualExpectedCols.size > query.output.size) { throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( - tableName, actualExpectedCols.map(_.name), query) + tableName, actualExpectedCols.map(_.name), query.output) } resolveColumnsByPosition(tableName, query.output, actualExpectedCols, conf, errors += _) } @@ -267,9 +267,13 @@ object TableOutputResolver { if (matchedCols.size < inputCols.length) { val extraCols = inputCols.filterNot(col => matchedCols.contains(col.name)) .map(col => s"${toSQLId(col.name)}").mkString(", ") - throw QueryCompilationErrors.incompatibleDataToTableExtraStructFieldsError( - tableName, colPath.quoted, extraCols - ) + if (colPath.isEmpty) { + throw QueryCompilationErrors.incompatibleDataToTableExtraColumnsError(tableName, + extraCols) + } else { + throw QueryCompilationErrors.incompatibleDataToTableExtraStructFieldsError( + tableName, colPath.quoted, extraCols) + } } else { reordered } @@ -290,16 +294,26 @@ object TableOutputResolver { val extraColsStr = inputCols.takeRight(inputCols.size - expectedCols.size) .map(col => toSQLId(col.name)) .mkString(", ") - throw QueryCompilationErrors.incompatibleDataToTableExtraStructFieldsError( - tableName, colPath.quoted, extraColsStr - ) + if (colPath.isEmpty) { + throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(tableName, + expectedCols.map(_.name), inputCols.map(_.toAttribute)) + } else { + throw QueryCompilationErrors.incompatibleDataToTableExtraStructFieldsError( + tableName, colPath.quoted, extraColsStr + ) + } } else if (inputCols.size < expectedCols.size) { val missingColsStr = expectedCols.takeRight(expectedCols.size - inputCols.size) .map(col => toSQLId(col.name)) 
.mkString(", ") - throw QueryCompilationErrors.incompatibleDataToTableStructMissingFieldsError( - tableName, colPath.quoted, missingColsStr - ) + if (colPath.isEmpty) { + throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(tableName, + expectedCols.map(_.name), inputCols.map(_.toAttribute)) + } else { + throw QueryCompilationErrors.incompatibleDataToTableStructMissingFieldsError( + tableName, colPath.quoted, missingColsStr + ) + } } inputCols.zip(expectedCols).flatMap { case (inputCol, expectedCol) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index a97abf8943406..ca101e79d9211 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2159,25 +2159,25 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def cannotWriteTooManyColumnsToTableError( tableName: String, expected: Seq[String], - query: LogicalPlan): Throwable = { + queryOutput: Seq[Attribute]): Throwable = { new AnalysisException( errorClass = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", messageParameters = Map( "tableName" -> toSQLId(tableName), "tableColumns" -> expected.map(c => toSQLId(c)).mkString(", "), - "dataColumns" -> query.output.map(c => toSQLId(c.name)).mkString(", "))) + "dataColumns" -> queryOutput.map(c => toSQLId(c.name)).mkString(", "))) } def cannotWriteNotEnoughColumnsToTableError( tableName: String, expected: Seq[String], - query: LogicalPlan): Throwable = { + queryOutput: Seq[Attribute]): Throwable = { new AnalysisException( errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", messageParameters = Map( "tableName" -> toSQLId(tableName), "tableColumns" -> expected.map(c => toSQLId(c)).mkString(", "), - "dataColumns" -> query.output.map(c => toSQLId(c.name)).mkString(", "))) + "dataColumns" -> queryOutput.map(c => toSQLId(c.name)).mkString(", "))) } def incompatibleDataToTableCannotFindDataError( @@ -2202,6 +2202,17 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def incompatibleDataToTableExtraColumnsError( + tableName: String, extraColumns: String): Throwable = { + new AnalysisException( + errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS", + messageParameters = Map( + "tableName" -> toSQLId(tableName), + "extraColumns" -> extraColumns + ) + ) + } + def incompatibleDataToTableExtraStructFieldsError( tableName: String, colName: String, extraFields: String): Throwable = { new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index 0bbed51d0a908..34e4ded09b5fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -213,9 +213,10 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { processInsert("t1", df, overwrite = false, byName = true) }, - v1ErrorClass = "_LEGACY_ERROR_TEMP_1186", + v1ErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS", v2ErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", - v1Parameters = Map.empty[String, String], + v1Parameters = Map("tableName" -> "`spark_catalog`.`default`.`t1`", + "extraColumns" -> 
"`x1`"), v2Parameters = Map("tableName" -> "`testcat`.`t1`", "colName" -> "`c1`") ) val df2 = Seq((3, 2, 1, 0)).toDF(Seq("c3", "c2", "c1", "c0"): _*) From e4114f67e12a235b4784fcbfa6f6ba9b44a5e715 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 1 Sep 2023 22:15:23 +0800 Subject: [PATCH 15/35] [SPARK-45048][CONNECT] Add additional tests for Python client and attachable execution ### What changes were proposed in this pull request? For better test coverage add additional tests of the attachable Spark Connect Python client. ### Why are the changes needed? Stability ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New test ### Was this patch authored or co-authored using generative AI tooling? No Closes #42769 from grundprinzip/SPARK-45048. Authored-by: Martin Grund Signed-off-by: Ruifeng Zheng --- .../sql/tests/connect/client/test_client.py | 156 +++++++++++++++++- 1 file changed, 154 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 2ba42cabf84a5..70280c1d24a7e 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -17,16 +17,21 @@ import unittest import uuid -from typing import Optional +from collections.abc import Generator +from typing import Optional, Any from pyspark.testing.connectutils import should_test_connect, connect_requirement_message if should_test_connect: + import grpc import pandas as pd import pyarrow as pa from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder from pyspark.sql.connect.client.core import Retrying - from pyspark.sql.connect.client.reattach import RetryException + from pyspark.sql.connect.client.reattach import ( + RetryException, + ExecutePlanResponseReattachableIterator, + ) import pyspark.sql.connect.proto as proto @@ -119,6 +124,153 @@ def test_channel_builder_with_session(self): self.assertEqual(client._session_id, chan.session_id) +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class SparkConnectClientReattachTestCase(unittest.TestCase): + def setUp(self) -> None: + self.request = proto.ExecutePlanRequest() + self.policy = { + "max_retries": 3, + "backoff_multiplier": 4.0, + "initial_backoff": 10, + "max_backoff": 10, + "jitter": 10, + "min_jitter_threshold": 10, + } + self.response = proto.ExecutePlanResponse() + self.finished = proto.ExecutePlanResponse( + result_complete=proto.ExecutePlanResponse.ResultComplete() + ) + + def _stub_with(self, execute=None, attach=None): + return MockSparkConnectStub( + execute_ops=ResponseGenerator(execute) if execute is not None else None, + attach_ops=ResponseGenerator(attach) if attach is not None else None, + ) + + def test_basic_flow(self): + stub = self._stub_with([self.response, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + self.assertEqual(0, stub.attach_calls) + self.assertGreater(1, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + def test_fail_during_execute(self): + def fatal(): + raise TestException("Fatal") + + stub = self._stub_with([self.response, fatal]) + with self.assertRaises(TestException): + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + self.assertEqual(0, stub.attach_calls) + self.assertEqual(0, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + def 
test_fail_and_retry_during_execute(self): + def non_fatal(): + raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) + + stub = self._stub_with( + [self.response, non_fatal], [self.response, self.response, self.finished] + ) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + self.assertEqual(1, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + def test_fail_and_retry_during_reattach(self): + count = 0 + + def non_fatal(): + nonlocal count + if count < 2: + count += 1 + raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) + else: + return proto.ExecutePlanResponse() + + stub = self._stub_with( + [self.response, non_fatal], [self.response, non_fatal, self.response, self.finished] + ) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + self.assertEqual(2, stub.attach_calls) + self.assertEqual(2, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + +class TestException(grpc.RpcError, grpc.Call): + """Exception mock to test retryable exceptions.""" + + def __init__(self, msg, code=grpc.StatusCode.INTERNAL): + self.msg = msg + self._code = code + + def code(self): + return self._code + + def __str__(self): + return self.msg + + def trailing_metadata(self): + return () + + +class ResponseGenerator(Generator): + """This class is used to generate values that are returned by the streaming + iterator of the GRPC stub.""" + + def __init__(self, funs): + self._funs = funs + self._iterator = iter(self._funs) + + def send(self, value: Any) -> proto.ExecutePlanResponse: + val = next(self._iterator) + if callable(val): + return val() + else: + return val + + def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any: + super().throw(type, value, traceback) + + def close(self) -> None: + return super().close() + + +class MockSparkConnectStub: + """Simple mock class for the GRPC stub used by the re-attachable execution.""" + + def __init__(self, execute_ops=None, attach_ops=None): + self._execute_ops = execute_ops + self._attach_ops = attach_ops + # Call counters + self.execute_calls = 0 + self.release_calls = 0 + self.attach_calls = 0 + + def ExecutePlan(self, *args, **kwargs): + self.execute_calls += 1 + return self._execute_ops + + def ReattachExecute(self, *args, **kwargs): + self.attach_calls += 1 + return self._attach_ops + + def ReleaseExecute(self, *args, **kwargs): + self.release_calls += 1 + + class MockService: # Simplest mock of the SparkConnectService. # If this needs more complex logic, it needs to be replaced with Python mocking. From 71d531300967655395250a3371c865c9d39f14b4 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 1 Sep 2023 09:58:20 -0700 Subject: [PATCH 16/35] [SPARK-44942][INFRA] Use Jira notification options to sync with Github ### What changes were proposed in this pull request? The dev/github_jira_sync.py does not work well these days, this PR tries to solve this issue by using a much more reliable service - https://cwiki.apache.org/confluence/display/INFRA/Git+-+.asf.yaml+features#Git.asf.yamlfeatures-Jiranotificationoptions And maybe we can remove dev/github_jira_sync.py but I don't know where we run this script yet. ### Why are the changes needed? for jira and github sync ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? 
Took a tour of a few Apache projects ### Was this patch authored or co-authored using generative AI tooling? no Closes #42750 from yaooqinn/SPARK-44942. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .asf.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.asf.yaml b/.asf.yaml index ae5e99cf230d8..22042b355b2fa 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -36,3 +36,4 @@ notifications: pullrequests: reviews@spark.apache.org issues: reviews@spark.apache.org commits: commits@spark.apache.org + jira_options: link label From 8c27de68756d4b0e5940211340a0b323d808aead Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 1 Sep 2023 10:01:36 -0700 Subject: [PATCH 17/35] [SPARK-44750][PYTHON][CONNECT][TESTS][FOLLOW-UP] Avoid creating session twice in `SparkConnectSessionWithOptionsTest` ### What changes were proposed in this pull request? Avoid creating session twice in `SparkConnectSessionWithOptionsTest` ### Why are the changes needed? the session created in `ReusedConnectTestCase#setUpClass` is not used, so no need to inherit ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #42747 from zhengruifeng/minor_test_ut. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/tests/connect/test_connect_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 65560725b46a5..2b97957061898 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3345,7 +3345,7 @@ def test_can_create_multiple_sessions_to_different_remotes(self): self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e)) -class SparkConnectSessionWithOptionsTest(ReusedConnectTestCase): +class SparkConnectSessionWithOptionsTest(unittest.TestCase): def setUp(self) -> None: self.spark = ( PySparkSession.builder.config("string", "foo") From e86849a84323bced40147bbea20eeb72924b11ca Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 1 Sep 2023 10:09:16 -0700 Subject: [PATCH 18/35] [SPARK-45037][INFRA] Upload unit tests log files for timeouted cancel ### What changes were proposed in this pull request? We currently upload ut logs on failures, while if there is a canceled state, there could be problems, too. We need those logs for bug hunting. ### Why are the changes needed? There may be evidence of a timeout cancellation in the logs. reference: https://docs.github.com/en/actions/learn-github-actions/expressions#status-check-functions ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? by github actions ### Was this patch authored or co-authored using generative AI tooling? no Closes #42756 from yaooqinn/SPARK-45037. 
Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .github/workflows/build_and_test.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index c0c21c7cb6eb0..beb5a7772b7f8 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -283,7 +283,7 @@ jobs: name: test-results-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} path: "**/target/test-reports/*.xml" - name: Upload unit tests log files - if: failure() + if: ${{ !success() }} uses: actions/upload-artifact@v3 with: name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} @@ -470,7 +470,7 @@ jobs: name: test-results-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files - if: failure() + if: ${{ !success() }} uses: actions/upload-artifact@v3 with: name: unit-tests-log-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3 @@ -961,7 +961,7 @@ jobs: name: test-results-tpcds--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files - if: failure() + if: ${{ !success() }} uses: actions/upload-artifact@v3 with: name: unit-tests-log-tpcds--8-${{ inputs.hadoop }}-hive2.3 @@ -1028,7 +1028,7 @@ jobs: name: test-results-docker-integration--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files - if: failure() + if: ${{ !success() }} uses: actions/upload-artifact@v3 with: name: unit-tests-log-docker-integration--8-${{ inputs.hadoop }}-hive2.3 @@ -1103,7 +1103,7 @@ jobs: eval $(minikube docker-env) build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" - name: Upload Spark on K8S integration tests log files - if: failure() + if: ${{ !success() }} uses: actions/upload-artifact@v3 with: name: spark-on-kubernetes-it-log From df534c355d9059fb5b128491a8f037baa121cbd7 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 1 Sep 2023 10:41:37 -0700 Subject: [PATCH 19/35] [SPARK-44952][SQL][PYTHON] Support named arguments in aggregate Pandas UDFs ### What changes were proposed in this pull request? Supports named arguments in aggregate Pandas UDFs. For example: ```py >>> pandas_udf("double") ... def weighted_mean(v: pd.Series, w: pd.Series) -> float: ... import numpy as np ... return np.average(v, weights=w) ... >>> df = spark.createDataFrame( ... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)], ... 
("id", "v", "w")) >>> df.groupby("id").agg(weighted_mean(v=df["v"], w=df["w"])).show() +---+-----------------------------+ | id|weighted_mean(v => v, w => w)| +---+-----------------------------+ | 1| 1.6666666666666667| | 2| 7.166666666666667| +---+-----------------------------+ >>> df.groupby("id").agg(weighted_mean(w=df["w"], v=df["v"])).show() +---+-----------------------------+ | id|weighted_mean(w => w, v => v)| +---+-----------------------------+ | 1| 1.6666666666666667| | 2| 7.166666666666667| +---+-----------------------------+ ``` or with window: ```py >>> w = Window.partitionBy("id").orderBy("v").rowsBetween(-2, 1) >>> df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)).show() +---+----+---+------------------+ | id| v| w| wm| +---+----+---+------------------+ | 1| 1.0|1.0|1.6666666666666667| | 1| 2.0|2.0|1.6666666666666667| | 2| 3.0|1.0| 4.333333333333333| | 2| 5.0|2.0| 7.166666666666667| | 2|10.0|3.0| 7.166666666666667| +---+----+---+------------------+ >>> df.withColumn("wm", weighted_mean_udf(w=df.w, v=df.v).over(w)).show() +---+----+---+------------------+ | id| v| w| wm| +---+----+---+------------------+ | 1| 1.0|1.0|1.6666666666666667| | 1| 2.0|2.0|1.6666666666666667| | 2| 3.0|1.0| 4.333333333333333| | 2| 5.0|2.0| 7.166666666666667| | 2|10.0|3.0| 7.166666666666667| +---+----+---+------------------+ ``` ### Why are the changes needed? Now that named arguments support was added (https://github.com/apache/spark/pull/41796, https://github.com/apache/spark/pull/42020). Aggregate Pandas UDFs can support it. ### Does this PR introduce _any_ user-facing change? Yes, named arguments will be available for aggregate Pandas UDFs. ### How was this patch tested? Added related tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42663 from ueshin/issues/SPARK-44952/kwargs. Authored-by: Takuya UESHIN Signed-off-by: Takuya UESHIN --- python/pyspark/sql/pandas/functions.py | 20 +- .../pandas/test_pandas_udf_grouped_agg.py | 147 ++++++++++++++- .../tests/pandas/test_pandas_udf_window.py | 173 +++++++++++++++++- python/pyspark/sql/tests/test_udf.py | 15 ++ python/pyspark/sql/tests/test_udtf.py | 15 ++ python/pyspark/worker.py | 25 +-- .../sql/catalyst/analysis/Analyzer.scala | 11 +- .../python/AggregateInPandasExec.scala | 23 ++- .../python/UserDefinedPythonFunction.scala | 3 +- .../WindowInPandasEvaluatorFactory.scala | 37 ++-- 10 files changed, 429 insertions(+), 40 deletions(-) diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index ad9fdac970639..652129180df94 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -57,7 +57,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): Supports Spark Connect. .. versionchanged:: 4.0.0 - Supports keyword-arguments in SCALAR type. + Supports keyword-arguments in SCALAR and GROUPED_AGG type. Parameters ---------- @@ -267,6 +267,24 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: | 2| 6.0| +---+-----------+ + This type of Pandas UDF can use keyword arguments: + + >>> @pandas_udf("double") + ... def weighted_mean_udf(v: pd.Series, w: pd.Series) -> float: + ... import numpy as np + ... return np.average(v, weights=w) + ... + >>> df = spark.createDataFrame( + ... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)], + ... 
("id", "v", "w")) + >>> df.groupby("id").agg(weighted_mean_udf(w=df["w"], v=df["v"])).show() + +---+---------------------------------+ + | id|weighted_mean_udf(w => w, v => v)| + +---+---------------------------------+ + | 1| 1.6666666666666667| + | 2| 7.166666666666667| + +---+---------------------------------+ + This UDF can also be used as window functions as below: >>> from pyspark.sql import Window diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py index f434489a6fb88..b500be7a96957 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py @@ -32,7 +32,7 @@ PandasUDFType, ) from pyspark.sql.types import ArrayType, YearMonthIntervalType -from pyspark.errors import AnalysisException, PySparkNotImplementedError +from pyspark.errors import AnalysisException, PySparkNotImplementedError, PythonException from pyspark.testing.sqlutils import ( ReusedSQLTestCase, have_pandas, @@ -40,7 +40,7 @@ pandas_requirement_message, pyarrow_requirement_message, ) -from pyspark.testing.utils import QuietTest +from pyspark.testing.utils import QuietTest, assertDataFrameEqual if have_pandas: @@ -575,6 +575,149 @@ def mean(x): assert filtered.collect()[0]["mean"] == 42.0 + def test_named_arguments(self): + df = self.data + weighted_mean = self.pandas_agg_weighted_mean_udf + + with self.tempView("v"): + df.createOrReplaceTempView("v") + self.spark.udf.register("weighted_mean", weighted_mean) + + for i, aggregated in enumerate( + [ + df.groupby("id").agg(weighted_mean(df.v, w=df.w).alias("wm")), + df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")), + df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")), + self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm FROM v GROUP BY id"), + self.spark.sql( + "SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id" + ), + self.spark.sql( + "SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id" + ), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm"))) + + def test_named_arguments_negative(self): + df = self.data + weighted_mean = self.pandas_agg_weighted_mean_udf + + with self.tempView("v"): + df.createOrReplaceTempView("v") + self.spark.udf.register("weighted_mean", weighted_mean) + + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql( + "SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id" + ).show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql( + "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id" + ).show() + + with self.assertRaisesRegex( + PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'" + ): + self.spark.sql( + "SELECT id, weighted_mean(v => v, x => w) as wm FROM v GROUP BY id" + ).show() + + with self.assertRaisesRegex( + PythonException, r"weighted_mean\(\) got multiple values for argument 'v'" + ): + self.spark.sql( + "SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id" + ).show() + + def test_kwargs(self): + df = self.data + + @pandas_udf("double", PandasUDFType.GROUPED_AGG) + def weighted_mean(**kwargs): + import numpy as np + + return np.average(kwargs["v"], weights=kwargs["w"]) + + with self.tempView("v"): + df.createOrReplaceTempView("v") + 
self.spark.udf.register("weighted_mean", weighted_mean) + + for i, aggregated in enumerate( + [ + df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")), + df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")), + self.spark.sql( + "SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id" + ), + self.spark.sql( + "SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id" + ), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm"))) + + # negative + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql( + "SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id" + ).show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql( + "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id" + ).show() + + def test_named_arguments_and_defaults(self): + df = self.data + + @pandas_udf("double", PandasUDFType.GROUPED_AGG) + def biased_sum(v, w=None): + return v.sum() + (w.sum() if w is not None else 100) + + with self.tempView("v"): + df.createOrReplaceTempView("v") + self.spark.udf.register("biased_sum", biased_sum) + + # without "w" + for i, aggregated in enumerate( + [ + df.groupby("id").agg(biased_sum(df.v).alias("s")), + df.groupby("id").agg(biased_sum(v=df.v).alias("s")), + self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP BY id"), + self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v GROUP BY id"), + ] + ): + with self.subTest(with_w=False, query_no=i): + assertDataFrameEqual( + aggregated, df.groupby("id").agg((sum(df.v) + lit(100)).alias("s")) + ) + + # with "w" + for i, aggregated in enumerate( + [ + df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")), + df.groupby("id").agg(biased_sum(v=df.v, w=df.w).alias("s")), + df.groupby("id").agg(biased_sum(w=df.w, v=df.v).alias("s")), + self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM v GROUP BY id"), + self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s FROM v GROUP BY id"), + self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s FROM v GROUP BY id"), + ] + ): + with self.subTest(with_w=True, query_no=i): + assertDataFrameEqual( + aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s")) + ) + class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py index e74e3783b1236..6968c0740943e 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py @@ -18,7 +18,7 @@ import unittest from typing import cast -from pyspark.errors import AnalysisException +from pyspark.errors import AnalysisException, PythonException from pyspark.sql.functions import ( array, explode, @@ -40,7 +40,7 @@ pandas_requirement_message, pyarrow_requirement_message, ) -from pyspark.testing.utils import QuietTest +from pyspark.testing.utils import QuietTest, assertDataFrameEqual if have_pandas: from pandas.testing import assert_frame_equal @@ -107,6 +107,16 @@ def min(v): return min + @property + def pandas_agg_weighted_mean_udf(self): + import numpy as np + + @pandas_udf("double", PandasUDFType.GROUPED_AGG) + def weighted_mean(v, w): + return np.average(v, weights=w) + + return weighted_mean + @property def unbounded_window(self): return ( @@ 
-394,6 +404,165 @@ def test_bounded_mixed(self): assert_frame_equal(expected1.toPandas(), result1.toPandas()) + def test_named_arguments(self): + df = self.data + weighted_mean = self.pandas_agg_weighted_mean_udf + + for w, bound in [(self.sliding_row_window, True), (self.unbounded_window, False)]: + for i, windowed in enumerate( + [ + df.withColumn("wm", weighted_mean(df.v, w=df.w).over(w)), + df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)), + df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)), + ] + ): + with self.subTest(bound=bound, query_no=i): + assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w))) + + with self.tempView("v"): + df.createOrReplaceTempView("v") + self.spark.udf.register("weighted_mean", weighted_mean) + + for w in [ + "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", + "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", + ]: + window_spec = f"PARTITION BY id ORDER BY v {w}" + for i, func_call in enumerate( + [ + "weighted_mean(v, w => w)", + "weighted_mean(v => v, w => w)", + "weighted_mean(w => w, v => v)", + ] + ): + with self.subTest(window_spec=window_spec, query_no=i): + assertDataFrameEqual( + self.spark.sql( + f"SELECT id, {func_call} OVER ({window_spec}) as wm FROM v" + ), + self.spark.sql(f"SELECT id, mean(v) OVER ({window_spec}) as wm FROM v"), + ) + + def test_named_arguments_negative(self): + df = self.data + weighted_mean = self.pandas_agg_weighted_mean_udf + + with self.tempView("v"): + df.createOrReplaceTempView("v") + self.spark.udf.register("weighted_mean", weighted_mean) + + base_sql = "SELECT id, {func_call} OVER ({window_spec}) as wm FROM v" + + for w in [ + "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", + "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", + ]: + window_spec = f"PARTITION BY id ORDER BY v {w}" + with self.subTest(window_spec=window_spec): + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql( + base_sql.format( + func_call="weighted_mean(v => v, v => w)", window_spec=window_spec + ) + ).show() + + with self.assertRaisesRegex( + AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT" + ): + self.spark.sql( + base_sql.format( + func_call="weighted_mean(v => v, w)", window_spec=window_spec + ) + ).show() + + with self.assertRaisesRegex( + PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'" + ): + self.spark.sql( + base_sql.format( + func_call="weighted_mean(v => v, x => w)", window_spec=window_spec + ) + ).show() + + with self.assertRaisesRegex( + PythonException, r"weighted_mean\(\) got multiple values for argument 'v'" + ): + self.spark.sql( + base_sql.format( + func_call="weighted_mean(v, v => w)", window_spec=window_spec + ) + ).show() + + def test_kwargs(self): + df = self.data + + @pandas_udf("double", PandasUDFType.GROUPED_AGG) + def weighted_mean(**kwargs): + import numpy as np + + return np.average(kwargs["v"], weights=kwargs["w"]) + + for w, bound in [(self.sliding_row_window, True), (self.unbounded_window, False)]: + for i, windowed in enumerate( + [ + df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)), + df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)), + ] + ): + with self.subTest(bound=bound, query_no=i): + assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w))) + + with self.tempView("v"): + df.createOrReplaceTempView("v") + self.spark.udf.register("weighted_mean", weighted_mean) + + base_sql = "SELECT id, {func_call} OVER 
({window_spec}) as wm FROM v" + + for w in [ + "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", + "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", + ]: + window_spec = f"PARTITION BY id ORDER BY v {w}" + with self.subTest(window_spec=window_spec): + for i, func_call in enumerate( + [ + "weighted_mean(v => v, w => w)", + "weighted_mean(w => w, v => v)", + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual( + self.spark.sql( + base_sql.format(func_call=func_call, window_spec=window_spec) + ), + self.spark.sql( + base_sql.format(func_call="mean(v)", window_spec=window_spec) + ), + ) + + # negative + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql( + base_sql.format( + func_call="weighted_mean(v => v, v => w)", window_spec=window_spec + ) + ).show() + + with self.assertRaisesRegex( + AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT" + ): + self.spark.sql( + base_sql.format( + func_call="weighted_mean(v => v, w)", window_spec=window_spec + ) + ).show() + class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index f72bf28823006..32ea05bd00a7f 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -939,6 +939,11 @@ def test_udf(a, b): ): self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show() + with self.assertRaisesRegex( + PythonException, r"test_udf\(\) got multiple values for argument 'a'" + ): + self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show() + def test_kwargs(self): @udf("int") def test_udf(**kwargs): @@ -957,6 +962,16 @@ def test_udf(**kwargs): with self.subTest(query_no=i): assertDataFrameEqual(df, [Row(0), Row(101)]) + # negative + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show() + def test_named_arguments_and_defaults(self): @udf("int") def test_udf(a, b=0): diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index a7545c332e6a0..95e46ba433cb9 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -1848,6 +1848,11 @@ def eval(self, a, b): ): self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show() + with self.assertRaisesRegex( + PythonException, r"eval\(\) got multiple values for argument 'a'" + ): + self.spark.sql("SELECT * FROM test_udtf(10, a => 100)").show() + def test_udtf_with_kwargs(self): @udtf(returnType="a: int, b: string") class TestUDTF: @@ -1867,6 +1872,16 @@ def eval(self, **kwargs): with self.subTest(query_no=i): assertDataFrameEqual(df, [Row(a=10, b="x")]) + # negative + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show() + def test_udtf_with_analyze_kwargs(self): @udtf class TestUDTF: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 
19c8c9c897b8e..d95a5c4672f86 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -452,13 +452,13 @@ def verify_element(result): def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) - def wrapped(*series): + def wrapped(*args, **kwargs): import pandas as pd - result = f(*series) + result = f(*args, **kwargs) return pd.Series([result]) - return lambda *a: (wrapped(*a), arrow_return_type) + return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type) def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index): @@ -484,19 +484,19 @@ def wrap_unbounded_window_agg_pandas_udf(f, return_type): # the scalar value. arrow_return_type = to_arrow_type(return_type) - def wrapped(*series): + def wrapped(*args, **kwargs): import pandas as pd - result = f(*series) - return pd.Series([result]).repeat(len(series[0])) + result = f(*args, **kwargs) + return pd.Series([result]).repeat(len((list(args) + list(kwargs.values()))[0])) - return lambda *a: (wrapped(*a), arrow_return_type) + return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type) def wrap_bounded_window_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) - def wrapped(begin_index, end_index, *series): + def wrapped(begin_index, end_index, *args, **kwargs): import pandas as pd result = [] @@ -521,11 +521,12 @@ def wrapped(begin_index, end_index, *series): # Note: Calling reset_index on the slices will increase the cost # of creating slices by about 100%. Therefore, for performance # reasons we don't do it here. - series_slices = [s.iloc[begin_array[i] : end_array[i]] for s in series] - result.append(f(*series_slices)) + args_slices = [s.iloc[begin_array[i] : end_array[i]] for s in args] + kwargs_slices = {k: s.iloc[begin_array[i] : end_array[i]] for k, s in kwargs.items()} + result.append(f(*args_slices, **kwargs_slices)) return pd.Series(result) - return lambda *a: (wrapped(*a), arrow_return_type) + return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): @@ -535,6 +536,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, # The below doesn't support named argument, but shares the same protocol. PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, ): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9a6d9c8b735be..b93f87e77b97f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3003,6 +3003,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // we need to make sure that col1 to col5 are all projected from the child of the Window // operator. 
val extractedExprMap = mutable.LinkedHashMap.empty[Expression, NamedExpression] + def getOrExtract(key: Expression, value: Expression): Expression = { + extractedExprMap.getOrElseUpdate(key.canonicalized, + Alias(value, s"_w${extractedExprMap.size}")()).toAttribute + } def extractExpr(expr: Expression): Expression = expr match { case ne: NamedExpression => // If a named expression is not in regularExpressions, add it to @@ -3016,11 +3020,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ne case e: Expression if e.foldable => e // No need to create an attribute reference if it will be evaluated as a Literal. + case e: NamedArgumentExpression => + // For NamedArgumentExpression, we extract the value and replace it with + // an AttributeReference (with an internal column name, e.g. "_w0"). + NamedArgumentExpression(e.key, getOrExtract(e, e.value)) case e: Expression => // For other expressions, we extract it and replace it with an AttributeReference (with // an internal column name, e.g. "_w0"). - extractedExprMap.getOrElseUpdate(e.canonicalized, - Alias(e, s"_w${extractedExprMap.size}")()).toAttribute + getOrExtract(e, e) } // Now, we extract regular expressions from expressionsWithWindowFunctions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 73560a596ca58..7e349b665f352 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -109,14 +110,20 @@ case class AggregateInPandasExec( // Also eliminate duplicate UDF inputs. 
val allInputs = new ArrayBuffer[Expression] val dataTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => + val argMetas = inputs.map { input => input.map { e => - if (allInputs.exists(_.semanticEquals(e))) { - allInputs.indexWhere(_.semanticEquals(e)) + val (key, value) = e match { + case NamedArgumentExpression(key, value) => + (Some(key), value) + case _ => + (None, e) + } + if (allInputs.exists(_.semanticEquals(value))) { + ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key) } else { - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 + allInputs += value + dataTypes += value.dataType + ArgumentMetadata(allInputs.length - 1, key) } }.toArray }.toArray @@ -164,10 +171,10 @@ case class AggregateInPandasExec( rows } - val columnarBatchIter = new ArrowPythonRunner( + val columnarBatchIter = new ArrowPythonWithNamedArgumentRunner( pyFuncs, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - argOffsets, + argMetas, aggInputSchema, sessionLocalTimeZone, largeVarTypes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index f576637aa25b7..2fcc428407ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -52,7 +52,8 @@ case class UserDefinedPythonFunction( def builder(e: Seq[Expression]): Expression = { if (pythonEvalType == PythonEvalType.SQL_BATCHED_UDF || pythonEvalType ==PythonEvalType.SQL_ARROW_BATCHED_UDF - || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) { + || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF + || pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) { /* * Check if the named arguments: * - don't have duplicated names diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala index a32d892622b4c..cf9f8c22ea082 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala @@ -25,11 +25,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{JobArtifactSet, PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedArgumentExpression, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray import org.apache.spark.sql.execution.metric.SQLMetric +import 
org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} @@ -170,14 +171,20 @@ class WindowInPandasEvaluatorFactory( // handles UDF inputs. private val dataInputs = new ArrayBuffer[Expression] private val dataInputTypes = new ArrayBuffer[DataType] - private val argOffsets = inputs.map { input => + private val argMetas = inputs.map { input => input.map { e => - if (dataInputs.exists(_.semanticEquals(e))) { - dataInputs.indexWhere(_.semanticEquals(e)) + val (key, value) = e match { + case NamedArgumentExpression(key, value) => + (Some(key), value) + case _ => + (None, e) + } + if (dataInputs.exists(_.semanticEquals(value))) { + ArgumentMetadata(dataInputs.indexWhere(_.semanticEquals(value)), key) } else { - dataInputs += e - dataInputTypes += e.dataType - dataInputs.length - 1 + dataInputs += value + dataInputTypes += value.dataType + ArgumentMetadata(dataInputs.length - 1, key) } }.toArray }.toArray @@ -206,11 +213,15 @@ class WindowInPandasEvaluatorFactory( pyFuncs.indices.foreach { exprIndex => val frameIndex = expressionIndexToFrameIndex(exprIndex) if (isBounded(frameIndex)) { - argOffsets(exprIndex) = - Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++ - argOffsets(exprIndex).map(_ + windowBoundsInput.length) + argMetas(exprIndex) = + Array( + ArgumentMetadata(lowerBoundIndex(frameIndex), None), + ArgumentMetadata(upperBoundIndex(frameIndex), None)) ++ + argMetas(exprIndex).map( + meta => ArgumentMetadata(meta.offset + windowBoundsInput.length, meta.name)) } else { - argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length) + argMetas(exprIndex) = argMetas(exprIndex).map( + meta => ArgumentMetadata(meta.offset + windowBoundsInput.length, meta.name)) } } @@ -346,10 +357,10 @@ class WindowInPandasEvaluatorFactory( } } - val windowFunctionResult = new ArrowPythonRunner( + val windowFunctionResult = new ArrowPythonWithNamedArgumentRunner( pyFuncs, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, - argOffsets, + argMetas, pythonInputSchema, sessionLocalTimeZone, largeVarTypes, From e4ebb372fa185c8cac60b339b1c9df06406ca08f Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 1 Sep 2023 17:20:58 -0700 Subject: [PATCH 20/35] [SPARK-44901][SQL] Add API in Python UDTF 'analyze' method to return partitioning/ordering expressions ### What changes were proposed in this pull request? This PR adds an API in the Python UDTF 'analyze' method to require partitioning/ordering properties from the input relation. Catalyst then performs necessary repartitioning and/or sorting as needed to fulfill the requested properties. 
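For context, here is a rough SQL sketch (table and column names are placeholders, not taken from this patch) of the explicit table-argument clause that these analyze-time properties stand in for; with the new API the UDTF requests the same partitioning and ordering once in `analyze`, so callers no longer have to repeat the clause on every invocation:

```sql
-- Without the new API, every caller would have to spell the requirement out:
SELECT * FROM my_udtf(TABLE(t) PARTITION BY partition_col ORDER BY input);
-- With partition_by/order_by returned from analyze(), a plain TABLE argument suffices:
SELECT * FROM my_udtf(TABLE(t));
```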
For example, the following property would request for Catalyst to behave as if the UDTF call included `PARTITION BY partition_col ORDER BY input`: ``` from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn from pyspark.sql.types import IntegerType, StructType udtf class MyUDTF: staticmethod def analyze(self): return AnalyzeResult( schema=StructType() .add("partition_col", IntegerType()) .add("count", IntegerType()) .add("total", IntegerType()) .add("last", IntegerType()), partition_by=[ PartitioningColumn("partition_col") ], order_by=[ OrderingColumn("input") ]) ... ``` Or, the following property would request for Catalyst to behave as if the UDTF call included `WITH SINGLE PARTITION`: ``` from pyspark.sql.functions import AnalyzeResult from pyspark.sql.types import IntegerType, StructType udtf class MyUDTF: staticmethod def analyze(self): return AnalyzeResult( schema=StructType() .add("partition_col", IntegerType()) .add("count", IntegerType()) .add("total", IntegerType()) .add("last", IntegerType()), with_single_partition=True) ... ``` ### Why are the changes needed? This gives Python UDTF authors the ability to write table functions that can assume constraints about which rows are consumed by which instances of the UDTF class. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds unit test coverage in Scala and Python. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42595 from dtenedor/anlayze-result. Authored-by: Daniel Tenedorio Signed-off-by: Takuya UESHIN --- .../main/resources/error/error-classes.json | 12 + docs/sql-error-conditions.md | 12 + python/pyspark/sql/functions.py | 1 + python/pyspark/sql/tests/test_udtf.py | 142 +++++- python/pyspark/sql/udtf.py | 44 +- python/pyspark/sql/worker/analyze_udtf.py | 18 + .../sql/catalyst/analysis/Analyzer.scala | 53 ++- .../sql/catalyst/expressions/PythonUDF.scala | 100 +++- .../sql/catalyst/expressions/generators.scala | 2 +- .../sql/errors/QueryCompilationErrors.scala | 22 + .../python/UserDefinedPythonFunction.scala | 49 +- .../analyzer-results/udtf/udtf.sql.out | 397 ++++++++++++++++ .../resources/sql-tests/inputs/udtf/udtf.sql | 91 ++++ .../sql-tests/results/udtf/udtf.sql.out | 448 ++++++++++++++++++ .../spark/sql/IntegratedUDFTestUtils.scala | 272 ++++++++++- .../apache/spark/sql/SQLQueryTestSuite.scala | 40 +- .../execution/python/PythonUDTFSuite.scala | 68 ++- 17 files changed, 1716 insertions(+), 55 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 87b9da7638b2a..c5a63dd68b9e0 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -2698,6 +2698,18 @@ "Failed to analyze the Python user defined table function: " ] }, + "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL" : { + "message" : [ + "Failed to evaluate the table function because its table metadata , but the function call ." + ], + "sqlState" : "22023" + }, + "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID" : { + "message" : [ + "Failed to evaluate the table function because its table metadata was invalid; ." + ], + "sqlState" : "22023" + }, "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS" : { "message" : [ "There are too many table arguments for table-valued function. It allows one table argument, but got: . 
If you want to allow it, please set \"spark.sql.allowMultipleTableArguments.enabled\" to \"true\"" diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 33072f6c44066..e25ef384a75cb 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -1764,6 +1764,18 @@ SQLSTATE: none assigned Failed to analyze the Python user defined table function: `` +### TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL + +[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception) + +Failed to evaluate the table function `` because its table metadata ``, but the function call ``. + +### TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID + +[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception) + +Failed to evaluate the table function `` because its table metadata was invalid; ``. + ### TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS SQLSTATE: none assigned diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0447bf0e19c8d..fb02cb0cc98b4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -50,6 +50,7 @@ # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult # noqa: F401 +from pyspark.sql.udtf import OrderingColumn, PartitioningColumn # noqa: F401 from pyspark.sql.udtf import UserDefinedTableFunction, _create_py_udtf # Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264 diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 95e46ba433cb9..c5f8b7693c26d 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -41,6 +41,8 @@ udtf, AnalyzeArgument, AnalyzeResult, + OrderingColumn, + PartitioningColumn, ) from pyspark.sql.types import ( ArrayType, @@ -2002,7 +2004,7 @@ def terminate(self): # This is a basic example. func = udtf(TestUDTF, returnType="partition_col: int, total: int") self.spark.udtf.register("test_udtf", func) - self.assertEqual( + assertDataFrameEqual( self.spark.sql( """ WITH t AS ( @@ -2023,7 +2025,7 @@ def terminate(self): ("123", "456", 123, 456), ("123", "NULL", None, 123), ): - self.assertEqual( + assertDataFrameEqual( self.spark.sql( f""" WITH t AS ( @@ -2045,7 +2047,7 @@ def terminate(self): # Combine a lateral join with a TABLE argument with PARTITION BY . 
func = udtf(TestUDTF, returnType="partition_col: int, total: int") self.spark.udtf.register("test_udtf", func) - self.assertEqual( + assertDataFrameEqual( self.spark.sql( """ WITH t AS ( @@ -2090,7 +2092,7 @@ def terminate(self): ("input DESC", 1), ("input - 1 DESC", 1), ): - self.assertEqual( + assertDataFrameEqual( self.spark.sql( f""" WITH t AS ( @@ -2130,7 +2132,7 @@ def terminate(self): func = udtf(TestUDTF, returnType="count: int, total: int, last: int") self.spark.udtf.register("test_udtf", func) - self.assertEqual( + assertDataFrameEqual( self.spark.sql( """ WITH t AS ( @@ -2143,7 +2145,135 @@ def terminate(self): ORDER BY 1, 2 """ ).collect(), - [Row(count=40, total=60, last=2)], + [ + Row(count=40, total=60, last=2), + ], + ) + + def test_udtf_with_table_argument_with_single_partition_from_analyze(self): + @udtf + class TestUDTF: + def __init__(self): + self._count = 0 + self._sum = 0 + self._last = None + + @staticmethod + def analyze(self): + return AnalyzeResult( + schema=StructType() + .add("count", IntegerType()) + .add("total", IntegerType()) + .add("last", IntegerType()), + with_single_partition=True, + order_by=[OrderingColumn("input"), OrderingColumn("partition_col")], + ) + + def eval(self, row: Row): + # Make sure that the rows arrive in the expected order. + if self._last is not None and self._last > row["input"]: + raise Exception( + f"self._last was {self._last} but the row value was {row['input']}" + ) + self._count += 1 + self._last = row["input"] + self._sum += row["input"] + + def terminate(self): + yield self._count, self._sum, self._last + + self.spark.udtf.register("test_udtf", TestUDTF) + + assertDataFrameEqual( + self.spark.sql( + """ + WITH t AS ( + SELECT id AS partition_col, 1 AS input FROM range(1, 21) + UNION ALL + SELECT id AS partition_col, 2 AS input FROM range(1, 21) + ) + SELECT count, total, last + FROM test_udtf(TABLE(t)) + ORDER BY 1, 2 + """ + ).collect(), + [ + Row(count=40, total=60, last=2), + ], + ) + + def test_udtf_with_table_argument_with_partition_by_and_order_by_from_analyze(self): + @udtf + class TestUDTF: + def __init__(self): + self._partition_col = None + self._count = 0 + self._sum = 0 + self._last = None + + @staticmethod + def analyze(self): + return AnalyzeResult( + schema=StructType() + .add("partition_col", IntegerType()) + .add("count", IntegerType()) + .add("total", IntegerType()) + .add("last", IntegerType()), + partition_by=[PartitioningColumn("partition_col")], + order_by=[ + OrderingColumn(name="input", ascending=True, overrideNullsFirst=False) + ], + ) + + def eval(self, row: Row): + # Make sure that all values of the partitioning column are the same + # for each row consumed by this method for this instance of the class. + if self._partition_col is not None and self._partition_col != row["partition_col"]: + raise Exception( + f"self._partition_col was {self._partition_col} but the row " + + f"value was {row['partition_col']}" + ) + # Make sure that the rows arrive in the expected order. 
+ if ( + self._last is not None + and row["input"] is not None + and self._last > row["input"] + ): + raise Exception( + f"self._last was {self._last} but the row value was {row['input']}" + ) + self._partition_col = row["partition_col"] + self._count += 1 + self._last = row["input"] + if row["input"] is not None: + self._sum += row["input"] + + def terminate(self): + yield self._partition_col, self._count, self._sum, self._last + + self.spark.udtf.register("test_udtf", TestUDTF) + + assertDataFrameEqual( + self.spark.sql( + """ + WITH t AS ( + SELECT id AS partition_col, 1 AS input FROM range(1, 21) + UNION ALL + SELECT id AS partition_col, 2 AS input FROM range(1, 21) + UNION ALL + SELECT 42 AS partition_col, NULL AS input + UNION ALL + SELECT 42 AS partition_col, 1 AS input + UNION ALL + SELECT 42 AS partition_col, 2 AS input + ) + SELECT partition_col, count, total, last + FROM test_udtf(TABLE(t)) + ORDER BY 1, 2 + """ + ).collect(), + [Row(partition_col=x, count=2, total=3, last=2) for x in range(1, 21)] + + [Row(partition_col=42, count=3, total=3, last=None)], ) diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 72bba3d9a2c48..ba4bac2ffdfa9 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -18,11 +18,11 @@ User-defined table function related classes and functions """ import pickle -from dataclasses import dataclass +from dataclasses import dataclass, field import inspect import sys import warnings -from typing import Any, Type, TYPE_CHECKING, Optional, Union +from typing import Any, Type, TYPE_CHECKING, Optional, Sequence, Union from py4j.java_gateway import JavaObject @@ -61,6 +61,30 @@ class AnalyzeArgument: is_table: bool +@dataclass(frozen=True) +class PartitioningColumn: + """ + Represents a UDTF column for purposes of returning metadata from the 'analyze' method. + """ + + name: str + + +@dataclass(frozen=True) +class OrderingColumn: + """ + Represents a single ordering column name for purposes of returning metadata from the 'analyze' + method. + """ + + name: str + ascending: bool = True + # If this is None, use the default behavior to sort NULL values first when sorting in ascending + # order, or last when sorting in descending order. Otherwise, if this is True or False, override + # the default behavior accordingly. + overrideNullsFirst: Optional[bool] = None + + @dataclass(frozen=True) class AnalyzeResult: """ @@ -70,9 +94,25 @@ class AnalyzeResult: ---------- schema : :class:`StructType` The schema that the Python UDTF will return. + with_single_partition : bool + If true, the UDTF is specifying for Catalyst to repartition all rows of the input TABLE + argument to one collection for consumption by exactly one instance of the correpsonding + UDTF class. + partition_by : Sequence[PartitioningColumn] + If non-empty, this is a sequence of columns that the UDTF is specifying for Catalyst to + partition the input TABLE argument by. In this case, calls to the UDTF may not include any + explicit PARTITION BY clause, in which case Catalyst will return an error. This option is + mutually exclusive with 'with_single_partition'. + order_by: Sequence[OrderingColumn] + If non-empty, this is a sequence of columns that the UDTF is specifying for Catalyst to + sort the input TABLE argument by. Note that the 'partition_by' list must also be non-empty + in this case. 
""" schema: StructType + with_single_partition: bool = False + partition_by: Sequence[PartitioningColumn] = field(default_factory=tuple) + order_by: Sequence[OrderingColumn] = field(default_factory=tuple) def _create_udtf( diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index 7ba0789fa7b71..29665b586a36e 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -127,6 +127,24 @@ def main(infile: IO, outfile: IO) -> None: # Return the analyzed schema. write_with_length(result.schema.json().encode("utf-8"), outfile) + # Return whether the "with single partition" property is requested. + write_int(1 if result.with_single_partition else 0, outfile) + # Return the list of partitioning columns, if any. + write_int(len(result.partition_by), outfile) + for partitioning_col in result.partition_by: + write_with_length(partitioning_col.name.encode("utf-8"), outfile) + # Return the requested input table ordering, if any. + write_int(len(result.order_by), outfile) + for ordering_col in result.order_by: + write_with_length(ordering_col.name.encode("utf-8"), outfile) + write_int(1 if ordering_col.ascending else 0, outfile) + if ordering_col.overrideNullsFirst is None: + write_int(0, outfile) + elif ordering_col.overrideNullsFirst: + write_int(1, outfile) + else: + write_int(2, outfile) + except BaseException as e: try: exc_info = None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b93f87e77b97f..a8c99075cdb80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2081,7 +2081,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => withPosition(u) { try { - val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { + val resolvedTvf = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name) if (CatalogV2Util.isSessionCatalog(catalog)) { v1SessionCatalog.resolvePersistentTableFunction( @@ -2091,7 +2091,18 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor catalog, "table-valued functions") } } - + // Resolve Python UDTF calls if needed. 
+          val resolvedFunc = resolvedTvf match {
+            case g @ Generate(u: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
+              val analyzeResult: PythonUDTFAnalyzeResult =
+                u.resolveElementMetadata(u.func, u.children)
+              g.copy(generator =
+                PythonUDTF(u.name, u.func, analyzeResult.schema, u.children,
+                  u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes,
+                  analyzeResult = Some(analyzeResult)))
+            case other =>
+              other
+          }
           val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan]
           val functionTableSubqueryArgs =
             mutable.ArrayBuffer.empty[FunctionTableSubqueryArgumentExpression]
@@ -2099,15 +2110,35 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId) {
             case t: FunctionTableSubqueryArgumentExpression =>
               val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
-              resolvedFunc match {
-                case Generate(_: PythonUDTF, _, _, _, _, _) =>
-                case _ =>
-                  assert(!t.hasRepartitioning,
-                    "Cannot evaluate the table-valued function call because it included the " +
-                    "PARTITION BY clause, but only Python table functions support this clause")
+              val (
+                pythonUDTFName: String,
+                pythonUDTFAnalyzeResult: Option[PythonUDTFAnalyzeResult]) =
+                resolvedFunc match {
+                  case Generate(p: PythonUDTF, _, _, _, _, _) =>
+                    (p.name,
+                      p.analyzeResult)
+                  case _ =>
+                    assert(!t.hasRepartitioning,
+                      "Cannot evaluate the table-valued function call because it included the " +
+                        "PARTITION BY clause, but only Python table functions support this " +
+                        "clause")
+                    ("", None)
+                }
+              // Check if this is a call to a Python user-defined table function whose polymorphic
+              // 'analyze' method returned metadata indicating requested partitioning and/or
+              // ordering properties of the input relation. In that event, make sure that the UDTF
+              // call did not include any explicit PARTITION BY and/or ORDER BY clauses for the
+              // corresponding TABLE argument, and then update the TABLE argument representation
+              // to apply the requested partitioning and/or ordering.
+ pythonUDTFAnalyzeResult.map { analyzeResult => + val newTableArgument: FunctionTableSubqueryArgumentExpression = + analyzeResult.applyToTableArgument(pythonUDTFName, t) + tableArgs.append(SubqueryAlias(alias, newTableArgument.evaluable)) + functionTableSubqueryArgs.append(newTableArgument) + }.getOrElse { + tableArgs.append(SubqueryAlias(alias, t.evaluable)) + functionTableSubqueryArgs.append(t) } - tableArgs.append(SubqueryAlias(alias, t.evaluable)) - functionTableSubqueryArgs.append(t) UnresolvedAttribute(Seq(alias, "c")) } if (tableArgs.nonEmpty) { @@ -2219,7 +2250,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) { - val elementSchema = u.resolveElementSchema(u.func, u.children) + val elementSchema = u.resolveElementMetadata(u.func, u.children).schema PythonUDTF(u.name, u.func, elementSchema, u.children, u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 787763a5bb4f7..a615348bc6ea8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDF, TreePattern} import org.apache.spark.sql.catalyst.util.toPrettySQL -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, StructType} /** @@ -154,8 +154,23 @@ abstract class UnevaluableGenerator extends Generator { } /** - * A serialized version of a Python table-valued function. This is a special expression, + * A serialized version of a Python table-valued function call. This is a special expression, * which needs a dedicated physical operator to execute it. 
+ * @param name name of the Python UDTF being called
+ * @param func string contents of the Python code in the UDTF, along with other environment state
+ * @param elementSchema result schema of the function call
+ * @param children input arguments to the UDTF call; for scalar arguments these are the expressions
+ *                 themselves, and for TABLE arguments, these are instances of
+ *                 [[FunctionTableSubqueryArgumentExpression]]
+ * @param evalType identifies whether this is a scalar or aggregate or table function, using an
+ *                 instance of the [[PythonEvalType]] enumeration
+ * @param udfDeterministic true if this function is deterministic wherein it returns the same result
+ *                         rows for every call with the same input arguments
+ * @param resultId unique expression ID for this function invocation
+ * @param pythonUDTFPartitionColumnIndexes holds the indexes of the TABLE argument to the Python
+ *                                         UDTF call, if applicable
+ * @param analyzeResult holds the result of the polymorphic Python UDTF 'analyze' method, if the UDTF
+ *                      defined one
  */
 case class PythonUDTF(
     name: String,
@@ -165,7 +180,8 @@ case class PythonUDTF(
     evalType: Int,
     udfDeterministic: Boolean,
     resultId: ExprId = NamedExpression.newExprId,
-    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
+    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None,
+    analyzeResult: Option[PythonUDTFAnalyzeResult] = None)
   extends UnevaluableGenerator with PythonFuncExpression {

   override lazy val canonicalized: Expression = {
@@ -193,7 +209,7 @@ case class UnresolvedPolymorphicPythonUDTF(
     children: Seq[Expression],
     evalType: Int,
     udfDeterministic: Boolean,
-    resolveElementSchema: (PythonFunction, Seq[Expression]) => StructType,
+    resolveElementMetadata: (PythonFunction, Seq[Expression]) => PythonUDTFAnalyzeResult,
     resultId: ExprId = NamedExpression.newExprId,
     pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
   extends UnevaluableGenerator with PythonFuncExpression {
@@ -207,6 +223,82 @@ case class UnresolvedPolymorphicPythonUDTF(
     copy(children = newChildren)
   }

+/**
+ * Represents the result of invoking the polymorphic 'analyze' method on a Python user-defined table
+ * function. This returns the table function's output schema in addition to other optional metadata.
+ * @param schema result schema of this particular function call in response to the particular + * arguments provided, including the types of any provided scalar arguments (and + * their values, in the case of literals) as well as the names and types of columns of + * the provided TABLE argument (if any) + * @param withSinglePartition true if the 'analyze' method explicitly indicated that the UDTF call + * should consume all rows of the input TABLE argument in a single + * instance of the UDTF class, in which case Catalyst will invoke a + * repartitioning to a separate stage with a single worker for this + * purpose + * @param partitionByExpressions if non-empty, this contains the list of column names that the + * 'analyze' method explicitly indicated that the UDTF call should + * partition the input table by, wherein all rows corresponding to + * each unique combination of values of the partitioning columns are + * consumed by exactly one unique instance of the UDTF class + * @param orderByExpressions if non-empty, this contains the list of ordering items that the + * 'analyze' method explicitly indicated that the UDTF call should consume + * the input table rows by + */ +case class PythonUDTFAnalyzeResult( + schema: StructType, + withSinglePartition: Boolean, + partitionByExpressions: Seq[Expression], + orderByExpressions: Seq[SortOrder]) { + /** + * Applies the requested properties from this analysis result to the target TABLE argument + * expression of a UDTF call, throwing an error if any properties of the UDTF call are + * incompatible. + */ + def applyToTableArgument( + pythonUDTFName: String, + t: FunctionTableSubqueryArgumentExpression): FunctionTableSubqueryArgumentExpression = { + if (withSinglePartition && partitionByExpressions.nonEmpty) { + throw QueryCompilationErrors.tableValuedFunctionRequiredMetadataInvalid( + functionName = pythonUDTFName, + reason = "the 'with_single_partition' field cannot be assigned to true " + + "if the 'partition_by' list is non-empty") + } + if (orderByExpressions.nonEmpty && !withSinglePartition && partitionByExpressions.isEmpty) { + throw QueryCompilationErrors.tableValuedFunctionRequiredMetadataInvalid( + functionName = pythonUDTFName, + reason = "the 'order_by' field cannot be non-empty unless the " + + "'with_single_partition' field is set to true or the 'partition_by' list " + + "is non-empty") + } + if ((withSinglePartition || partitionByExpressions.nonEmpty) && t.hasRepartitioning) { + throw QueryCompilationErrors + .tableValuedFunctionRequiredMetadataIncompatibleWithCall( + functionName = pythonUDTFName, + requestedMetadata = + "specified its own required partitioning of the input table", + invalidFunctionCallProperty = + "specified the WITH SINGLE PARTITION or PARTITION BY clause; " + + "please remove these clauses and retry the query again.") + } + var newWithSinglePartition = t.withSinglePartition + var newPartitionByExpressions = t.partitionByExpressions + var newOrderByExpressions = t.orderByExpressions + if (withSinglePartition) { + newWithSinglePartition = true + } + if (partitionByExpressions.nonEmpty) { + newPartitionByExpressions = partitionByExpressions + } + if (orderByExpressions.nonEmpty) { + newOrderByExpressions = orderByExpressions + } + t.copy( + withSinglePartition = newWithSinglePartition, + partitionByExpressions = newPartitionByExpressions, + orderByExpressions = newOrderByExpressions) + } +} + /** * A place holder used when printing expressions without debugging information such as the * result id. 
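To make the two validation branches in `applyToTableArgument` concrete, here is a hedged Python sketch of `analyze` results that each check rejects; they mirror the `UDTFInvalidPartitionByAndWithSinglePartition` and `UDTFInvalidOrderByWithoutPartitionBy` test functions registered later in this patch:

```
# Hedged sketches of 'analyze' results that applyToTableArgument rejects.
from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn
from pyspark.sql.types import IntegerType, StructType

# Rejected: 'with_single_partition' and a non-empty 'partition_by' list are
# mutually exclusive.
invalid_single_partition_and_partition_by = AnalyzeResult(
    schema=StructType().add("last", IntegerType()),
    with_single_partition=True,
    partition_by=[PartitioningColumn("partition_col")])

# Rejected: 'order_by' requires either 'with_single_partition=True' or a
# non-empty 'partition_by' list.
invalid_order_by_without_partitioning = AnalyzeResult(
    schema=StructType().add("last", IntegerType()),
    order_by=[OrderingColumn("input")])
```

Returning either result from `analyze` causes Catalyst to raise `TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID` when the UDTF is invoked, as exercised by the golden files below.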
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index afaaf07d2726b..ae144d067755d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -58,7 +58,7 @@ trait Generator extends Expression { override def nullable: Boolean = false - final override val nodePatterns: Seq[TreePattern] = Seq(GENERATOR) + protected override val nodePatterns: Seq[TreePattern] = Seq(GENERATOR) /** * The output element schema. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index ca101e79d9211..989c24430557a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3764,4 +3764,26 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "supported" -> "constant expressions"), cause = cause) } + + def tableValuedFunctionRequiredMetadataIncompatibleWithCall( + functionName: String, + requestedMetadata: String, + invalidFunctionCallProperty: String): Throwable = { + new AnalysisException( + errorClass = "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + messageParameters = Map( + "functionName" -> functionName, + "requestedMetadata" -> requestedMetadata, + "invalidFunctionCallProperty" -> invalidFunctionCallProperty)) + } + + def tableValuedFunctionRequiredMetadataInvalid( + functionName: String, + reason: String): Throwable = { + new AnalysisException( + errorClass = "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + messageParameters = Map( + "functionName" -> functionName, + "reason" -> reason)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 2fcc428407ecc..2beefedc9467e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets import java.util.HashMap import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import net.razorvine.pickle.Pickler @@ -32,7 +33,8 @@ import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorker import org.apache.spark.internal.config.BUFFER_SIZE import org.apache.spark.internal.config.Python._ import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, PythonUDAF, PythonUDF, PythonUDTF, UnresolvedPolymorphicPythonUDTF} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, SortOrder, UnresolvedPolymorphicPythonUDTF} import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, 
NamedParametersSupport, OneRowRelation} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -145,7 +147,7 @@ case class UserDefinedPythonTableFunction( children = exprs, evalType = pythonEvalType, udfDeterministic = udfDeterministic, - resolveElementSchema = UserDefinedPythonTableFunction.analyzeInPython(_, _, tableArgs)) + resolveElementMetadata = UserDefinedPythonTableFunction.analyzeInPython(_, _, tableArgs)) } Generate( udtf, @@ -191,7 +193,9 @@ object UserDefinedPythonTableFunction { * will be thrown when an exception is raised in Python. */ def analyzeInPython( - func: PythonFunction, exprs: Seq[Expression], tableArgs: Seq[Boolean]): StructType = { + func: PythonFunction, + exprs: Seq[Expression], + tableArgs: Seq[Boolean]): PythonUDTFAnalyzeResult = { val env = SparkEnv.get val bufferSize: Int = env.conf.get(BUFFER_SIZE) val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) @@ -272,7 +276,7 @@ object UserDefinedPythonTableFunction { val dataIn = new DataInputStream(new BufferedInputStream( new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize)) - // Receive the schema + // Receive the schema. val schema = dataIn.readInt() match { case length if length >= 0 => val obj = new Array[Byte](length) @@ -286,6 +290,37 @@ object UserDefinedPythonTableFunction { val msg = new String(obj, StandardCharsets.UTF_8) throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg) } + // Receive whether the "with single partition" property is requested. + val withSinglePartition = dataIn.readInt() == 1 + // Receive the list of requested partitioning columns, if any. + val partitionByColumns = ArrayBuffer.empty[Expression] + val numPartitionByColumns = dataIn.readInt() + for (_ <- 0 until numPartitionByColumns) { + val length = dataIn.readInt() + val obj = new Array[Byte](length) + dataIn.readFully(obj) + val columnName = new String(obj, StandardCharsets.UTF_8) + partitionByColumns.append(UnresolvedAttribute(columnName)) + } + // Receive the list of requested ordering columns, if any. 
+ val orderBy = ArrayBuffer.empty[SortOrder] + val numOrderByItems = dataIn.readInt() + for (_ <- 0 until numOrderByItems) { + val length = dataIn.readInt() + val obj = new Array[Byte](length) + dataIn.readFully(obj) + val columnName = new String(obj, StandardCharsets.UTF_8) + val direction = if (dataIn.readInt() == 1) Ascending else Descending + val overrideNullsFirst = dataIn.readInt() + overrideNullsFirst match { + case 0 => + orderBy.append(SortOrder(UnresolvedAttribute(columnName), direction)) + case 1 => orderBy.append( + SortOrder(UnresolvedAttribute(columnName), direction, NullsFirst, Seq.empty)) + case 2 => orderBy.append( + SortOrder(UnresolvedAttribute(columnName), direction, NullsLast, Seq.empty)) + } + } PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn) Option(func.accumulator).foreach(_.merge(maybeAccumulator.get)) @@ -298,7 +333,11 @@ object UserDefinedPythonTableFunction { } releasedOrClosed = true - schema + PythonUDTFAnalyzeResult( + schema = schema, + withSinglePartition = withSinglePartition, + partitionByExpressions = partitionByColumns.toSeq, + orderByExpressions = orderBy.toSeq) } catch { case eof: EOFException => throw new SparkException("Python worker exited unexpectedly (crashed)", eof) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out index b46a1f230a856..f7b2bada26ecb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out @@ -1,4 +1,16 @@ -- Automatically generated by SQLQueryTestSuite +-- !query +DROP VIEW IF EXISTS t1 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`t1`, true, true, false + + +-- !query +DROP VIEW IF EXISTS t2 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`t2`, true, true, false + + -- !query CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2) -- !query analysis @@ -7,6 +19,14 @@ CreateViewCommand `t1`, VALUES (0, 1), (1, 2) t(c1, c2), false, true, LocalTempV +- LocalRelation [c1#x, c2#x] +-- !query +CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input) +-- !query analysis +CreateViewCommand `t2`, VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input), false, true, LocalTempView, true + +- SubqueryAlias t + +- LocalRelation [partition_col#x, input#x] + + -- !query SELECT * FROM udtf(1, 2) -- !query analysis @@ -59,3 +79,380 @@ SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2) SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1) -- !query analysis [Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFCountSumLast(TABLE(t2) WITH SINGLE PARTITION) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input DESC) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input DESC) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : 
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_DETERMINISTIC_LATERAL_SUBQUERIES", + "sqlState" : "0A000", + "messageParameters" : { + "treeNode" : "LateralJoin lateral-subquery#x [], Inner\n: +- Project [count#x, total#x, last#x]\n: +- LateralJoin lateral-subquery#x [c#x], Inner\n: : +- SubqueryAlias __auto_generated_subquery_name_1\n: : +- Generate UDTFCountSumLast(outer(c#x))#x, false, [count#x, total#x, last#x]\n: : +- OneRowRelation\n: +- SubqueryAlias __auto_generated_subquery_name_0\n: +- Project [named_struct(partition_col, partition_col#x, input, input#x, partition_by_0, partition_by_0#x) AS c#x]\n: +- Sort [partition_by_0#x ASC NULLS FIRST, input#x DESC NULLS LAST], false\n: +- RepartitionByExpression [partition_by_0#x]\n: +- Project [partition_col#x, input#x, partition_col#x AS partition_by_0#x]\n: +- SubqueryAlias t2\n: +- View (`t2`, [partition_col#x,input#x])\n: +- Project [cast(partition_col#x as int) AS partition_col#x, cast(input#x as int) AS input#x]\n: +- SubqueryAlias t\n: +- LocalRelation [partition_col#x, input#x]\n+- SubqueryAlias t\n +- LocalRelation [col#x]\n" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 49, + "stopIndex" : 139, + "fragment" : "JOIN LATERAL\n UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input DESC)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(TABLE(t2)) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFWithSinglePartition", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 70, + "fragment" : "UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFWithSinglePartition", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 75, + "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFWithSinglePartition", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : 
"specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 66, + "stopIndex" : 126, + "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2)) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2) WITH SINGLE PARTITION) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFPartitionByOrderBy", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 69, + "fragment" : "UDTFPartitionByOrderBy(TABLE(t2) WITH SINGLE PARTITION)" + } ] +} + + +-- !query +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFPartitionByOrderBy", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 74, + "fragment" : "UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFPartitionByOrderBy", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 66, + "stopIndex" : 125, + "fragment" : "UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2)) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidPartitionByAndWithSinglePartition", + "reason" : "the 'with_single_partition' field cannot be assigned to true if the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 69, + "fragment" : "UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2))" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + 
"errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidPartitionByAndWithSinglePartition", + "reason" : "the 'with_single_partition' field cannot be assigned to true if the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 91, + "fragment" : "UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidPartitionByAndWithSinglePartition", + "reason" : "the 'with_single_partition' field cannot be assigned to true if the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 96, + "fragment" : "UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidPartitionByAndWithSinglePartition", + "reason" : "the 'with_single_partition' field cannot be assigned to true if the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 66, + "stopIndex" : 147, + "fragment" : "UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2)) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidOrderByWithoutPartitionBy", + "reason" : "the 'order_by' field cannot be non-empty unless the 'with_single_partition' field is set to true or the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 61, + "fragment" : "UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2))" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) WITH SINGLE PARTITION) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidOrderByWithoutPartitionBy", + "reason" : "the 'order_by' field cannot be non-empty unless the 'with_single_partition' field is set to true or the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 83, + "fragment" : "UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) WITH SINGLE PARTITION)" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + 
"sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidOrderByWithoutPartitionBy", + "reason" : "the 'order_by' field cannot be non-empty unless the 'with_single_partition' field is set to true or the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 88, + "fragment" : "UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidOrderByWithoutPartitionBy", + "reason" : "the 'order_by' field cannot be non-empty unless the 'with_single_partition' field is set to true or the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 66, + "stopIndex" : 139, + "fragment" : "UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +DROP VIEW t1 +-- !query analysis +DropTempViewCommand t1 + + +-- !query +DROP VIEW t2 +-- !query analysis +DropTempViewCommand t2 diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql index 66044604d64c0..6d49177c4f6a9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql @@ -1,4 +1,7 @@ +DROP VIEW IF EXISTS t1; +DROP VIEW IF EXISTS t2; CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2); +CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input); -- test basic udtf SELECT * FROM udtf(1, 2); @@ -16,3 +19,91 @@ SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2); -- test non-deterministic input SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1); + +-- test UDTF calls that take input TABLE arguments +-- As a reminder, the UDTFCountSumLast function returns this analyze result: +-- AnalyzeResult( +-- schema=StructType() +-- .add("count", IntegerType()) +-- .add("total", IntegerType()) +-- .add("last", IntegerType())) +SELECT * FROM UDTFCountSumLast(TABLE(t2) WITH SINGLE PARTITION); +SELECT * FROM UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input); +SELECT * FROM UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input DESC); +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input DESC); + +-- test UDTF calls that take input TABLE arguments and the 'analyze' method returns required +-- partitioning and/or ordering properties for Catalyst to enforce for the input table +-- As a reminder, the UDTFWithSinglePartition function returns this analyze result: +-- AnalyzeResult( +-- schema=StructType() +-- .add("count", IntegerType()) +-- .add("total", IntegerType()) +-- .add("last", IntegerType()), +-- with_single_partition=True, +-- order_by=[ +-- OrderingColumn("input"), +-- OrderingColumn("partition_col")]) +SELECT * FROM UDTFWithSinglePartition(TABLE(t2)); +SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION); +SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col); +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + 
UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col); +-- As a reminder, the UDTFPartitionByOrderBy function returns this analyze result: +-- AnalyzeResult( +-- schema=StructType() +-- .add("partition_col", IntegerType()) +-- .add("count", IntegerType()) +-- .add("total", IntegerType()) +-- .add("last", IntegerType()), +-- partition_by=[ +-- PartitioningColumn("partition_col") +-- ], +-- order_by=[ +-- OrderingColumn("input") +-- ]) +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2)); +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2) WITH SINGLE PARTITION); +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col); +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col); +-- As a reminder, UDTFInvalidPartitionByAndWithSinglePartition returns this analyze result: +-- AnalyzeResult( +-- schema=StructType() +-- .add("last", IntegerType()), +-- with_single_partition=True, +-- partition_by=[ +-- PartitioningColumn("partition_col") +-- ]) +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2)); +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION); +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col); +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col); +-- As a reminder, UDTFInvalidOrderByWithoutPartitionBy function returns this analyze result: +-- AnalyzeResult( +-- schema=StructType() +-- .add("last", IntegerType()), +-- order_by=[ +-- OrderingColumn("input") +-- ]) +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2)); +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) WITH SINGLE PARTITION); +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col); +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col); + +-- cleanup +DROP VIEW t1; +DROP VIEW t2; diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out index 4f91ed3b70e58..a93aac9450156 100644 --- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out @@ -1,4 +1,20 @@ -- Automatically generated by SQLQueryTestSuite +-- !query +DROP VIEW IF EXISTS t1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW IF EXISTS t2 +-- !query schema +struct<> +-- !query output + + + -- !query CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2) -- !query schema @@ -7,6 +23,14 @@ struct<> +-- !query +CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input) +-- !query schema +struct<> +-- !query output + + + -- !query SELECT * FROM udtf(1, 2) -- !query schema @@ -83,3 +107,427 @@ struct -- !query output 1 0 1 0 + + +-- !query +SELECT * FROM UDTFCountSumLast(TABLE(t2) WITH SINGLE PARTITION) +-- !query schema +struct +-- !query output +3 6 3 + + +-- !query +SELECT * FROM UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input) +-- !query schema +struct +-- !query output +1 1 1 +2 5 3 + + +-- !query +SELECT * FROM UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input DESC) +-- !query schema +struct +-- !query output +1 1 1 +2 5 2 + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN 
LATERAL + UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input DESC) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_DETERMINISTIC_LATERAL_SUBQUERIES", + "sqlState" : "0A000", + "messageParameters" : { + "treeNode" : "LateralJoin lateral-subquery#x [], Inner\n: +- Project [count#x, total#x, last#x]\n: +- LateralJoin lateral-subquery#x [c#x], Inner\n: : +- SubqueryAlias __auto_generated_subquery_name_1\n: : +- Generate UDTFCountSumLast(outer(c#x))#x, false, [count#x, total#x, last#x]\n: : +- OneRowRelation\n: +- SubqueryAlias __auto_generated_subquery_name_0\n: +- Project [named_struct(partition_col, partition_col#x, input, input#x, partition_by_0, partition_by_0#x) AS c#x]\n: +- Sort [partition_by_0#x ASC NULLS FIRST, input#x DESC NULLS LAST], false\n: +- RepartitionByExpression [partition_by_0#x]\n: +- Project [partition_col#x, input#x, partition_col#x AS partition_by_0#x]\n: +- SubqueryAlias t2\n: +- View (`t2`, [partition_col#x,input#x])\n: +- Project [cast(partition_col#x as int) AS partition_col#x, cast(input#x as int) AS input#x]\n: +- SubqueryAlias t\n: +- LocalRelation [partition_col#x, input#x]\n+- SubqueryAlias t\n +- LocalRelation [col#x]\n" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 49, + "stopIndex" : 139, + "fragment" : "JOIN LATERAL\n UDTFCountSumLast(TABLE(t2) PARTITION BY partition_col ORDER BY input DESC)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(TABLE(t2)) +-- !query schema +struct +-- !query output +3 6 3 + + +-- !query +SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFWithSinglePartition", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 70, + "fragment" : "UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)" + } ] +} + + +-- !query +SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFWithSinglePartition", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 75, + "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + 
"sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFWithSinglePartition", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 66, + "stopIndex" : 126, + "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2)) +-- !query schema +struct +-- !query output +0 1 1 1 +1 2 5 3 + + +-- !query +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2) WITH SINGLE PARTITION) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFPartitionByOrderBy", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 69, + "fragment" : "UDTFPartitionByOrderBy(TABLE(t2) WITH SINGLE PARTITION)" + } ] +} + + +-- !query +SELECT * FROM UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFPartitionByOrderBy", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 74, + "fragment" : "UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFPartitionByOrderBy", + "invalidFunctionCallProperty" : "specified the WITH SINGLE PARTITION or PARTITION BY clause; please remove these clauses and retry the query again.", + "requestedMetadata" : "specified its own required partitioning of the input table" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 66, + "stopIndex" : 125, + "fragment" : "UDTFPartitionByOrderBy(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidPartitionByAndWithSinglePartition", + "reason" : "the 'with_single_partition' field cannot be assigned to true if the 'partition_by' list is non-empty" + }, + 
"queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 69, + "fragment" : "UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2))" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidPartitionByAndWithSinglePartition", + "reason" : "the 'with_single_partition' field cannot be assigned to true if the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 91, + "fragment" : "UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidPartitionByAndWithSinglePartition", + "reason" : "the 'with_single_partition' field cannot be assigned to true if the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 96, + "fragment" : "UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidPartitionByAndWithSinglePartition", + "reason" : "the 'with_single_partition' field cannot be assigned to true if the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 66, + "stopIndex" : 147, + "fragment" : "UDTFInvalidPartitionByAndWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidOrderByWithoutPartitionBy", + "reason" : "the 'order_by' field cannot be non-empty unless the 'with_single_partition' field is set to true or the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 61, + "fragment" : "UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2))" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) WITH SINGLE PARTITION) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidOrderByWithoutPartitionBy", + "reason" : "the 'order_by' field cannot be non-empty unless the 'with_single_partition' field is set to true or 
the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 83, + "fragment" : "UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) WITH SINGLE PARTITION)" + } ] +} + + +-- !query +SELECT * FROM UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidOrderByWithoutPartitionBy", + "reason" : "the 'order_by' field cannot be non-empty unless the 'with_single_partition' field is set to true or the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 88, + "fragment" : "UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +SELECT * FROM + VALUES (0), (1) AS t(col) + JOIN LATERAL + UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "UDTFInvalidOrderByWithoutPartitionBy", + "reason" : "the 'order_by' field cannot be non-empty unless the 'with_single_partition' field is set to true or the 'partition_by' list is non-empty" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 66, + "stopIndex" : 139, + "fragment" : "UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col)" + } ] +} + + +-- !query +DROP VIEW t1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW t2 +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 508818c2e501d..05f71500a0f52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -320,7 +320,9 @@ object IntegratedUDFTestUtils extends SQLHelper { sealed trait TestUDTF { def apply(session: SparkSession, exprs: Column*): DataFrame + val name: String val prettyName: String + val udtf: UserDefinedPythonTableFunction } class PythonUDFWithoutId( @@ -395,7 +397,7 @@ object IntegratedUDFTestUtils extends SQLHelper { def createUserDefinedPythonTableFunction( name: String, pythonScript: String, - returnType: StructType, + returnType: Option[StructType], evalType: Int = PythonEvalType.SQL_TABLE_UDF, deterministic: Boolean = false): UserDefinedPythonTableFunction = { UserDefinedPythonTableFunction( @@ -408,13 +410,13 @@ object IntegratedUDFTestUtils extends SQLHelper { pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, accumulator = null), - returnType = Some(returnType), + returnType = returnType, pythonEvalType = evalType, udfDeterministic = deterministic) } case class TestPythonUDTF(name: String) extends TestUDTF { - private val pythonScript: String = + val pythonScript: String = """ |class TestUDTF: | def eval(self, a: int, b: int): @@ -427,10 +429,10 @@ object IntegratedUDFTestUtils extends SQLHelper { | ... 
|""".stripMargin - private[IntegratedUDFTestUtils] lazy val udtf = createUserDefinedPythonTableFunction( + val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( name = "TestUDTF", pythonScript = pythonScript, - returnType = StructType.fromDDL("x int, y int") + returnType = Some(StructType.fromDDL("x int, y int")) ) def apply(session: SparkSession, exprs: Column*): DataFrame = @@ -439,6 +441,255 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String = "Regular Python UDTF" } + object TestPythonUDTFCountSumLast extends TestUDTF { + val name: String = "UDTFCountSumLast" + val pythonScript: String = + s""" + |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn + |from pyspark.sql.types import IntegerType, Row, StructType + |class $name: + | def __init__(self): + | self._count = 0 + | self._sum = 0 + | self._last = None + | + | @staticmethod + | def analyze(self): + | return AnalyzeResult( + | schema=StructType() + | .add("count", IntegerType()) + | .add("total", IntegerType()) + | .add("last", IntegerType())) + | + | def eval(self, row: Row): + | self._count += 1 + | self._last = row["input"] + | self._sum += row["input"] + | + | def terminate(self): + | yield self._count, self._sum, self._last + |""".stripMargin + + val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + name = name, + pythonScript = pythonScript, + returnType = None) + + def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) + + val prettyName: String = + "Python UDTF finding the count, sum, and last value from the input rows" + } + + object TestPythonUDTFLastString extends TestUDTF { + val name: String = "UDTFLastString" + val pythonScript: String = + s""" + |from pyspark.sql.functions import AnalyzeResult + |from pyspark.sql.types import Row, StringType, StructType + |class $name: + | def __init__(self): + | self._last = "" + | + | @staticmethod + | def analyze(self): + | return AnalyzeResult( + | schema=StructType() + | .add("last", StringType())) + | + | def eval(self, row: Row): + | self._last = row["input"] + | + | def terminate(self): + | yield self._last, + |""".stripMargin + + val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + name = name, + pythonScript = pythonScript, + returnType = None) + + def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) + + val prettyName: String = "Python UDTF returning the last string provided in the input table" + } + + + object TestPythonUDTFWithSinglePartition extends TestUDTF { + val name: String = "UDTFWithSinglePartition" + val pythonScript: String = + s""" + |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn + |from pyspark.sql.types import IntegerType, Row, StructType + |class $name: + | def __init__(self): + | self._count = 0 + | self._sum = 0 + | self._last = None + | + | @staticmethod + | def analyze(self): + | return AnalyzeResult( + | schema=StructType() + | .add("count", IntegerType()) + | .add("total", IntegerType()) + | .add("last", IntegerType()), + | with_single_partition=True, + | order_by=[ + | OrderingColumn("input"), + | OrderingColumn("partition_col")]) + | + | def eval(self, row: Row): + | self._count += 1 + | self._last = row["input"] + | self._sum += row["input"] + | + | def terminate(self): + | yield self._count, self._sum, self._last + |""".stripMargin + + val udtf: UserDefinedPythonTableFunction = 
createUserDefinedPythonTableFunction( + name = name, + pythonScript = pythonScript, + returnType = None) + + def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) + + val prettyName: String = "Python UDTF exporting single-partition requirement from 'analyze'" + } + + object TestPythonUDTFPartitionBy extends TestUDTF { + val name: String = "UDTFPartitionByOrderBy" + val pythonScript: String = + s""" + |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn + |from pyspark.sql.types import IntegerType, Row, StructType + |class $name: + | def __init__(self): + | self._partition_col = None + | self._count = 0 + | self._sum = 0 + | self._last = None + | + | @staticmethod + | def analyze(self): + | return AnalyzeResult( + | schema=StructType() + | .add("partition_col", IntegerType()) + | .add("count", IntegerType()) + | .add("total", IntegerType()) + | .add("last", IntegerType()), + | partition_by=[ + | PartitioningColumn("partition_col") + | ], + | order_by=[ + | OrderingColumn("input") + | ]) + | + | def eval(self, row: Row): + | self._partition_col = row["partition_col"] + | self._count += 1 + | self._last = row["input"] + | self._sum += row["input"] + | + | def terminate(self): + | yield self._partition_col, self._count, self._sum, self._last + |""".stripMargin + + val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + name = name, + pythonScript = pythonScript, + returnType = None) + + def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) + + val prettyName: String = + "Python UDTF exporting input table partitioning and ordering requirement from 'analyze'" + } + + object TestPythonUDTFInvalidPartitionByAndWithSinglePartition extends TestUDTF { + val name: String = "UDTFInvalidPartitionByAndWithSinglePartition" + val pythonScript: String = + s""" + |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn + |from pyspark.sql.types import IntegerType, Row, StructType + |class $name: + | def __init__(self): + | self._last = None + | + | @staticmethod + | def analyze(self): + | return AnalyzeResult( + | schema=StructType() + | .add("last", IntegerType()), + | with_single_partition=True, + | partition_by=[ + | PartitioningColumn("partition_col") + | ]) + | + | def eval(self, row: Row): + | self._last = row["input"] + | + | def terminate(self): + | yield self._last, + |""".stripMargin + + val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + name = name, + pythonScript = pythonScript, + returnType = None) + + def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) + + val prettyName: String = + "Python UDTF exporting invalid input table partitioning requirement from 'analyze' " + + "because the 'with_single_partition' property is also exported to true" + } + + object TestPythonUDTFInvalidOrderByWithoutPartitionBy extends TestUDTF { + val name: String = "UDTFInvalidOrderByWithoutPartitionBy" + val pythonScript: String = + s""" + |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn + |from pyspark.sql.types import IntegerType, Row, StructType + |class $name: + | def __init__(self): + | self._last = None + | + | @staticmethod + | def analyze(self): + | return AnalyzeResult( + | schema=StructType() + | .add("last", IntegerType()), + | order_by=[ + | OrderingColumn("input") + | ]) + | + | def eval(self, row: Row): + | self._last = 
row["input"] + | + | def terminate(self): + | yield self._last, + |""".stripMargin + + val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + name = name, + pythonScript = pythonScript, + returnType = None) + + def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) + + val prettyName: String = + "Python UDTF exporting invalid input table ordering requirement from 'analyze' " + + "without a corresponding partitioning table requirement" + } + /** * A Scalar Pandas UDF that takes one column, casts into string, executes the * Python native function, and casts back to the type of input column. @@ -622,8 +873,13 @@ object IntegratedUDFTestUtils extends SQLHelper { /** * Register UDTFs used in the test cases. */ - def registerTestUDTF(testUDTF: TestUDTF, session: SparkSession): Unit = testUDTF match { - case udtf: TestPythonUDTF => session.udtf.registerPython(udtf.name, udtf.udtf) - case other => throw new RuntimeException(s"Unknown UDTF class [${other.getClass}]") + case class TestUDTFSet(udtfs: Seq[TestUDTF]) + def registerTestUDTFs(testUDTFSet: TestUDTFSet, session: SparkSession): Unit = { + testUDTFSet.udtfs.foreach { + _ match { + case udtf: TestUDTF => session.udtf.registerPython(udtf.name, udtf.udtf) + case other => throw new RuntimeException(s"Unknown UDTF class [${other.getClass}]") + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 0b27f80a5d02f..36899d1157833 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -202,8 +202,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper val udf: TestUDF } - protected trait UDTFTest { - val udtf: TestUDTF + protected trait UDTFSetTest { + val udtfSet: TestUDTFSet } /** A regular test case. */ @@ -241,14 +241,14 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper UDFAnalyzerTestCase(newName, inputFile, newResultFile, udf) } - protected case class UDTFTestCase( + protected case class UDTFSetTestCase( name: String, inputFile: String, resultFile: String, - udtf: TestUDTF) extends TestCase with UDTFTest { + udtfSet: TestUDTFSet) extends TestCase with UDTFSetTest { override def asAnalyzerTest(newName: String, newResultFile: String): TestCase = - UDTFAnalyzerTestCase(newName, inputFile, newResultFile, udtf) + UDTFSetAnalyzerTestCase(newName, inputFile, newResultFile, udtfSet) } /** A UDAF test case. 
*/ @@ -299,9 +299,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper protected case class UDFAnalyzerTestCase( name: String, inputFile: String, resultFile: String, udf: TestUDF) extends AnalyzerTest with UDFTest - protected case class UDTFAnalyzerTestCase( - name: String, inputFile: String, resultFile: String, udtf: TestUDTF) - extends AnalyzerTest with UDTFTest + protected case class UDTFSetAnalyzerTestCase( + name: String, inputFile: String, resultFile: String, udtfSet: TestUDTFSet) + extends AnalyzerTest with UDTFSetTest protected case class UDAFAnalyzerTestCase( name: String, inputFile: String, resultFile: String, udf: TestUDF) extends AnalyzerTest with UDFTest @@ -505,8 +505,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper testCase match { case udfTestCase: UDFTest => registerTestUDF(udfTestCase.udf, localSparkSession) - case udtfTestCase: UDTFTest => - registerTestUDTF(udtfTestCase.udtf, localSparkSession) + case udtfTestCase: UDTFSetTest => + registerTestUDTFs(udtfTestCase.udtfSet, localSparkSession) case _ => } @@ -593,8 +593,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper shouldTestPandasUDFs => s"${testCase.name}${System.lineSeparator()}" + s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}" - case udtfTestCase: UDTFTest - if udtfTestCase.udtf.isInstanceOf[TestPythonUDTF] && shouldTestPythonUDFs => + case udtfTestCase: UDTFSetTest + if udtfTestCase.udtfSet.udtfs.forall(_.isInstanceOf[TestPythonUDTF]) && + shouldTestPythonUDFs => s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}" case _ => s"${testCase.name}${System.lineSeparator()}" @@ -643,10 +644,17 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf) } } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udtf")) { - Seq(TestPythonUDTF("udtf")).map { udtf => - UDTFTestCase( - s"$testCaseName - ${udtf.prettyName}", absPath, resultFile, udtf - ) + Seq(TestUDTFSet(Seq( + TestPythonUDTF("udtf"), + TestPythonUDTFCountSumLast, + TestPythonUDTFLastString, + TestPythonUDTFWithSinglePartition, + TestPythonUDTFPartitionBy, + TestPythonUDTFInvalidPartitionByAndWithSinglePartition, + TestPythonUDTFInvalidOrderByWithoutPartitionBy + ))).map { udtfSet => + UDTFSetTestCase( + s"$testCaseName - Python UDTFs", absPath, resultFile, udtfSet) } } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) { PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala index 7e5ac53975107..cf687f902871f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala @@ -44,13 +44,25 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { private val returnType: StructType = StructType.fromDDL("a int, b int, c int") private val pythonUDTF: UserDefinedPythonTableFunction = - createUserDefinedPythonTableFunction("SimpleUDTF", pythonScript, returnType) + createUserDefinedPythonTableFunction("SimpleUDTF", pythonScript, Some(returnType)) + + private val pythonUDTFCountSumLast: UserDefinedPythonTableFunction = + 
createUserDefinedPythonTableFunction( + "UDTFCountSumLast", TestPythonUDTFCountSumLast.pythonScript, None) + + private val pythonUDTFWithSinglePartition: UserDefinedPythonTableFunction = + createUserDefinedPythonTableFunction( + "UDTFWithSinglePartition", TestPythonUDTFWithSinglePartition.pythonScript, None) + + private val pythonUDTFPartitionByOrderBy: UserDefinedPythonTableFunction = + createUserDefinedPythonTableFunction( + "UDTFPartitionByOrderBy", TestPythonUDTFPartitionBy.pythonScript, None) private val arrowPythonUDTF: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( "SimpleUDTF", pythonScript, - returnType, + Some(returnType), evalType = PythonEvalType.SQL_ARROW_TABLE_UDF) test("Simple PythonUDTF") { @@ -189,6 +201,58 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { start = 14, stop = 30)) } + + spark.udtf.registerPython("UDTFCountSumLast", pythonUDTFCountSumLast) + var plan = sql( + """ + |WITH t AS ( + | VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input) + |) + |SELECT count, total, last + |FROM UDTFCountSumLast(TABLE(t) WITH SINGLE PARTITION) + |ORDER BY 1, 2 + |""".stripMargin).queryExecution.analyzed + plan.collectFirst { case r: Repartition => r } match { + case Some(Repartition(1, true, _)) => + case _ => + failure(plan) + } + + spark.udtf.registerPython("UDTFWithSinglePartition", pythonUDTFWithSinglePartition) + plan = sql( + """ + |WITH t AS ( + | SELECT id AS partition_col, 1 AS input FROM range(1, 21) + | UNION ALL + | SELECT id AS partition_col, 2 AS input FROM range(1, 21) + |) + |SELECT count, total, last + |FROM UDTFWithSinglePartition(TABLE(t)) + |ORDER BY 1, 2 + |""".stripMargin).queryExecution.analyzed + plan.collectFirst { case r: Repartition => r } match { + case Some(Repartition(1, true, _)) => + case _ => + failure(plan) + } + + spark.udtf.registerPython("UDTFPartitionByOrderBy", pythonUDTFPartitionByOrderBy) + plan = sql( + """ + |WITH t AS ( + | SELECT id AS partition_col, 1 AS input FROM range(1, 21) + | UNION ALL + | SELECT id AS partition_col, 2 AS input FROM range(1, 21) + |) + |SELECT partition_col, count, total, last + |FROM UDTFPartitionByOrderBy(TABLE(t)) + |ORDER BY 1, 2 + |""".stripMargin).queryExecution.analyzed + plan.collectFirst { case r: RepartitionByExpression => r } match { + case Some(_: RepartitionByExpression) => + case _ => + failure(plan) + } } test("SPARK-44503: Compute partition child indexes for various UDTF argument lists") { From e9962e89335843fb6be3d0ab8ddef6a3667cc0c3 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Fri, 1 Sep 2023 20:22:34 -0700 Subject: [PATCH 21/35] [SPARK-45054][SQL] HiveExternalCatalog.listPartitions should restore partition statistics ### What changes were proposed in this pull request? Call `restorePartitionMetadata` in `listPartitions` to restore Spark SQL statistics. ### Why are the changes needed? Currently when `listPartitions` is called, it doesn't restore Spark SQL statistics stored in metastore, such as `spark.sql.statistics.totalSize`. This means callers who rely on stats from the method call may get wrong results. In particular, when `spark.sql.statistics.size.autoUpdate.enabled` is turned on, during insert overwrite Spark will first list partitions and get old statistics, and then compare them with new statistics and see which partitions need to be updated. This issue will sometimes cause it to update all partitions instead of only those partitions that have been touched. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested?
Added a new test. ### Was this patch authored or co-authored using generative AI tooling? Closes #42777 from sunchao/list-partition-stat. Authored-by: Chao Sun Signed-off-by: Chao Sun --- .../catalog/ExternalCatalogSuite.scala | 25 +++++++++++++++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 7 ++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 32eb884942763..a8f73cebf31e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -475,6 +475,31 @@ abstract class ExternalCatalogSuite extends SparkFunSuite { assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty) } + test("SPARK-45054: list partitions should restore stats") { + val catalog = newBasicCatalog() + val stats = Some(CatalogStatistics(sizeInBytes = 1)) + val newPart = CatalogTablePartition(Map("a" -> "1", "b" -> "2"), storageFormat, stats = stats) + catalog.alterPartitions("db2", "tbl2", Seq(newPart)) + val parts = catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "1"))) + + assert(parts.length == 1) + val part = parts.head + assert(part.stats.exists(_.sizeInBytes == 1)) + } + + test("SPARK-45054: list partitions by filter should restore stats") { + val catalog = newBasicCatalog() + val stats = Some(CatalogStatistics(sizeInBytes = 1)) + val newPart = CatalogTablePartition(Map("a" -> "1", "b" -> "2"), storageFormat, stats = stats) + catalog.alterPartitions("db2", "tbl2", Seq(newPart)) + val tz = TimeZone.getDefault.getID + val parts = catalog.listPartitionsByFilter("db2", "tbl2", Seq($"a".int === 1), tz) + + assert(parts.length == 1) + val part = parts.head + assert(part.stats.exists(_.sizeInBytes == 1)) + } + test("SPARK-21457: list partitions with special chars") { val catalog = newBasicCatalog() assert(catalog.listPartitions("db2", "tbl1").isEmpty) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 67b780f13c431..e4325989b7066 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1275,13 +1275,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat db: String, table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient { - val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table)) + val catalogTable = getTable(db, table) + val partColNameMap = buildLowerCasePartColNameMap(catalogTable) val metaStoreSpec = partialSpec.map(toMetaStorePartitionSpec) val res = client.getPartitions(db, table, metaStoreSpec) .map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) } - metaStoreSpec match { + val parts = metaStoreSpec match { // This might be a bug of Hive: When the partition value inside the partial partition spec // contains dot, and we ask Hive to list partitions w.r.t. 
the partial partition spec, Hive // treats dot as matching any single character and may return more partitions than we @@ -1290,6 +1291,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat res.filter(p => isPartialPartitionSpec(spec, toMetaStorePartitionSpec(p.spec))) case _ => res } + parts.map(restorePartitionMetadata(_, catalogTable)) } override def listPartitionsByFilter( @@ -1303,6 +1305,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val clientPrunedPartitions = client.getPartitionsByFilter(rawHiveTable, predicates).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) + restorePartitionMetadata(part, catalogTable) } prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId) } From f0fb434c268f69e6845ba97e3256d3c1b873fc95 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 2 Sep 2023 17:20:22 +0800 Subject: [PATCH 22/35] [SPARK-45026][CONNECT][FOLLOW-UP] Code cleanup ### What changes were proposed in this pull request? move 3 variables to `isCommand` branch ### Why are the changes needed? they are not used in other branches ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? NO Closes #42765 from zhengruifeng/SPARK-45026-followup. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../sql/connect/planner/SparkConnectPlanner.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 547b6a9fb4039..11300631491d9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2464,15 +2464,15 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { case _ => Seq.empty } - // Convert the results to Arrow. - val schema = df.schema - val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong - val timeZoneId = session.sessionState.conf.sessionLocalTimeZone - // To avoid explicit handling of the result on the client, we build the expected input // of the relation on the server. The client has to simply forward the result. val result = SqlCommandResult.newBuilder() if (isCommand) { + // Convert the results to Arrow. + val schema = df.schema + val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong + val timeZoneId = session.sessionState.conf.sessionLocalTimeZone + // Convert the data. val bytes = if (rows.isEmpty) { ArrowConverters.createEmptyArrowBatch( From 82d54fc8924618777992ee9a4d939b1fb336f20d Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 2 Sep 2023 08:18:43 -0500 Subject: [PATCH 23/35] [SPARK-45043][BUILD] Upgrade `scalafmt` to 3.7.13 ### What changes were proposed in this pull request? The pr aims to upgrade `scalafmt` from 3.7.5 to 3.7.13. ### Why are the changes needed? 
1. The newest version includes some bug fixes, e.g.: - FormatWriter: accumulate align shift correctly (https://github.com/scalameta/scalafmt/pull/3615) - Indents: ignore fewerBraces if indentation is 1 (https://github.com/scalameta/scalafmt/pull/3592) - RemoveScala3OptionalBraces: handle infix on rbrace (https://github.com/scalameta/scalafmt/pull/3576) 2. The full release notes: https://github.com/scalameta/scalafmt/releases/tag/v3.7.13 https://github.com/scalameta/scalafmt/releases/tag/v3.7.12 https://github.com/scalameta/scalafmt/releases/tag/v3.7.11 https://github.com/scalameta/scalafmt/releases/tag/v3.7.10 https://github.com/scalameta/scalafmt/releases/tag/v3.7.9 https://github.com/scalameta/scalafmt/releases/tag/v3.7.8 https://github.com/scalameta/scalafmt/releases/tag/v3.7.7 https://github.com/scalameta/scalafmt/releases/tag/v3.7.6 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42764 from panbingkun/SPARK-45043. Authored-by: panbingkun Signed-off-by: Sean Owen --- dev/.scalafmt.conf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/.scalafmt.conf b/dev/.scalafmt.conf index c3b26002a7690..721dec289900b 100644 --- a/dev/.scalafmt.conf +++ b/dev/.scalafmt.conf @@ -32,4 +32,4 @@ fileOverride { runner.dialect = scala213 } } -version = 3.7.5 +version = 3.7.13 From 967aac1171a49c8e98c992512487d77c2b1c4565 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 2 Sep 2023 08:19:38 -0500 Subject: [PATCH 24/35] [SPARK-44956][BUILD] Upgrade Jekyll to 4.3.2 & Webrick to 1.8.1 ### What changes were proposed in this pull request? The PR aims to upgrade - Jekyll from 4.2.1 to 4.3.2. - Webrick from 1.7 to 1.8.1. ### Why are the changes needed? 1. The `4.2.1` version was released on Sep 27, 2021, and it has been 2 years since then. 2. Jekyll 4.3.2 was released on `Jan 21, 2023`, and includes the fix for a regression bug. - https://github.com/jekyll/jekyll/releases/tag/v4.3.2 - https://github.com/jekyll/jekyll/releases/tag/v4.3.1 - https://github.com/jekyll/jekyll/releases/tag/v4.3.0 Fix regression in Convertible module from v4.2.0 (https://github.com/jekyll/jekyll/pull/8786) - https://github.com/jekyll/jekyll/releases/tag/v4.2.2 3. The newest webrick version includes some bug fixes. https://github.com/ruby/webrick/releases/tag/v1.8.1 https://github.com/ruby/webrick/releases/tag/v1.8.0 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Pass GA. - Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42669 from panbingkun/SPARK-44956.
Authored-by: panbingkun Signed-off-by: Sean Owen --- docs/Gemfile | 4 +-- docs/Gemfile.lock | 62 +++++++++++++++++++++++++---------------------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/docs/Gemfile b/docs/Gemfile index 6c35201296480..6c6760371163c 100644 --- a/docs/Gemfile +++ b/docs/Gemfile @@ -18,7 +18,7 @@ source "https://rubygems.org" gem "ffi", "1.15.5" -gem "jekyll", "4.2.1" +gem "jekyll", "4.3.2" gem "rouge", "3.26.0" gem "jekyll-redirect-from", "0.16.0" -gem "webrick", "1.7" +gem "webrick", "1.8.1" diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 6654e6c47c615..eda31f857476e 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -1,74 +1,78 @@ GEM remote: https://rubygems.org/ specs: - addressable (2.8.0) - public_suffix (>= 2.0.2, < 5.0) + addressable (2.8.5) + public_suffix (>= 2.0.2, < 6.0) colorator (1.1.0) - concurrent-ruby (1.1.9) - em-websocket (0.5.2) + concurrent-ruby (1.2.2) + em-websocket (0.5.3) eventmachine (>= 0.12.9) - http_parser.rb (~> 0.6.0) + http_parser.rb (~> 0) eventmachine (1.2.7) ffi (1.15.5) forwardable-extended (2.6.0) - http_parser.rb (0.6.0) - i18n (1.8.11) + google-protobuf (3.24.2) + http_parser.rb (0.8.0) + i18n (1.14.1) concurrent-ruby (~> 1.0) - jekyll (4.2.1) + jekyll (4.3.2) addressable (~> 2.4) colorator (~> 1.0) em-websocket (~> 0.5) i18n (~> 1.0) - jekyll-sass-converter (~> 2.0) + jekyll-sass-converter (>= 2.0, < 4.0) jekyll-watch (~> 2.0) - kramdown (~> 2.3) + kramdown (~> 2.3, >= 2.3.1) kramdown-parser-gfm (~> 1.0) liquid (~> 4.0) - mercenary (~> 0.4.0) + mercenary (>= 0.3.6, < 0.5) pathutil (~> 0.9) - rouge (~> 3.0) + rouge (>= 3.0, < 5.0) safe_yaml (~> 1.0) - terminal-table (~> 2.0) + terminal-table (>= 1.8, < 4.0) + webrick (~> 1.7) jekyll-redirect-from (0.16.0) jekyll (>= 3.3, < 5.0) - jekyll-sass-converter (2.1.0) - sassc (> 2.0.1, < 3.0) + jekyll-sass-converter (3.0.0) + sass-embedded (~> 1.54) jekyll-watch (2.2.1) listen (~> 3.0) - kramdown (2.3.1) + kramdown (2.4.0) rexml kramdown-parser-gfm (1.1.0) kramdown (~> 2.0) - liquid (4.0.3) - listen (3.7.0) + liquid (4.0.4) + listen (3.8.0) rb-fsevent (~> 0.10, >= 0.10.3) rb-inotify (~> 0.9, >= 0.9.10) mercenary (0.4.0) pathutil (0.16.2) forwardable-extended (~> 2.6) - public_suffix (4.0.6) - rb-fsevent (0.11.0) + public_suffix (5.0.3) + rake (13.0.6) + rb-fsevent (0.11.2) rb-inotify (0.10.1) ffi (~> 1.0) - rexml (3.2.5) + rexml (3.2.6) rouge (3.26.0) safe_yaml (1.0.5) - sassc (2.4.0) - ffi (~> 1.9) - terminal-table (2.0.0) - unicode-display_width (~> 1.1, >= 1.1.1) - unicode-display_width (1.8.0) - webrick (1.7.0) + sass-embedded (1.63.6) + google-protobuf (~> 3.23) + rake (>= 13.0.0) + terminal-table (3.0.2) + unicode-display_width (>= 1.1.1, < 3) + unicode-display_width (2.4.2) + webrick (1.8.1) PLATFORMS ruby DEPENDENCIES ffi (= 1.15.5) - jekyll (= 4.2.1) + jekyll (= 4.3.2) jekyll-redirect-from (= 0.16.0) rouge (= 3.26.0) - webrick (= 1.7) + webrick (= 1.8.1) BUNDLED WITH 2.3.8 From 3e22c8653d728a6b8523051faddcca437accfc22 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Sat, 2 Sep 2023 16:07:09 -0700 Subject: [PATCH 25/35] [SPARK-44640][PYTHON][FOLLOW-UP] Update UDTF error messages to include method name ### What changes were proposed in this pull request? This PR is a follow-up for SPARK-44640 to make the error message of a few UDTF errors more informative by including the method name in the error message (`eval` or `terminate`). ### Why are the changes needed? To improve error messages. ### Does this PR introduce _any_ user-facing change? 
No ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42726 from allisonwang-db/SPARK-44640-follow-up. Authored-by: allisonwang-db Signed-off-by: Takuya UESHIN --- python/pyspark/errors/error_classes.py | 8 +++--- python/pyspark/sql/tests/test_udtf.py | 21 +++++++++++++++ python/pyspark/worker.py | 37 +++++++++++++++++++------- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index ca448a169e83b..74f52c416e95b 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -244,7 +244,7 @@ }, "INVALID_ARROW_UDTF_RETURN_TYPE" : { "message" : [ - "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the function returned a value of type with value: ." + "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '' method returned a value of type with value: ." ] }, "INVALID_BROADCAST_OPERATION": { @@ -745,17 +745,17 @@ }, "UDTF_INVALID_OUTPUT_ROW_TYPE" : { "message" : [ - "The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got ''. Please make sure that the output rows are of the correct type." + "The type of an individual output row in the '' method of the UDTF is invalid. Each row should be a tuple, list, or dict, but got ''. Please make sure that the output rows are of the correct type." ] }, "UDTF_RETURN_NOT_ITERABLE" : { "message" : [ - "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got ''. Please make sure that the UDTF returns one of these types." + "The return value of the '' method of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got ''. Please make sure that the UDTF returns one of these types." ] }, "UDTF_RETURN_SCHEMA_MISMATCH" : { "message" : [ - "The number of columns in the result does not match the specified schema. Expected column count: , Actual column count: . Please make sure the values returned by the function have the same number of columns as specified in the output schema." + "The number of columns in the result does not match the specified schema. Expected column count: , Actual column count: . Please make sure the values returned by the '' method have the same number of columns as specified in the output schema." ] }, "UDTF_RETURN_TYPE_MISMATCH" : { diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index c5f8b7693c26d..97d5190a5060c 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -190,6 +190,27 @@ def eval(self, a): with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): TestUDTF(lit(1)).collect() + def test_udtf_with_zero_arg_and_invalid_return_value(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + return 1 + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF().collect() + + def test_udtf_with_invalid_return_value_in_terminate(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self, a): + ... 
+ + def terminate(self): + return 1 + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF(lit(1)).collect() + def test_udtf_eval_with_no_return(self): @udtf(returnType="a: int") class TestUDTF: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d95a5c4672f86..fff99f1de3d06 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -773,6 +773,7 @@ def verify_result(result): message_parameters={ "type_name": type(result).__name__, "value": str(result), + "func": f.__name__, }, ) @@ -787,6 +788,7 @@ def verify_result(result): message_parameters={ "expected": str(return_type_size), "actual": str(len(result.columns)), + "func": f.__name__, }, ) @@ -806,9 +808,23 @@ def func(*a: Any, **kw: Any) -> Any: message_parameters={"method_name": f.__name__, "error": str(e)}, ) + def check_return_value(res): + # Check whether the result of an arrow UDTF is iterable before + # using it to construct a pandas DataFrame. + if res is not None and not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={ + "type": type(res).__name__, + "func": f.__name__, + }, + ) + def evaluate(*args: pd.Series, **kwargs: pd.Series): if len(args) == 0 and len(kwargs) == 0: - yield verify_result(pd.DataFrame(func())), arrow_return_type + res = func() + check_return_value(res) + yield verify_result(pd.DataFrame(res)), arrow_return_type else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. @@ -820,13 +836,7 @@ def evaluate(*args: pd.Series, **kwargs: pd.Series): *row[:len_args], **{key: row[len_args + i] for i, key in enumerate(keys)}, ) - if res is not None and not isinstance(res, Iterable): - raise PySparkRuntimeError( - error_class="UDTF_RETURN_NOT_ITERABLE", - message_parameters={ - "type": type(res).__name__, - }, - ) + check_return_value(res) yield verify_result(pd.DataFrame(res)), arrow_return_type return evaluate @@ -868,13 +878,17 @@ def verify_and_convert_result(result): message_parameters={ "expected": str(return_type_size), "actual": str(len(result)), + "func": f.__name__, }, ) if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")): raise PySparkRuntimeError( error_class="UDTF_INVALID_OUTPUT_ROW_TYPE", - message_parameters={"type": type(result).__name__}, + message_parameters={ + "type": type(result).__name__, + "func": f.__name__, + }, ) return toInternal(result) @@ -898,7 +912,10 @@ def evaluate(*a, **kw) -> tuple: if not isinstance(res, Iterable): raise PySparkRuntimeError( error_class="UDTF_RETURN_NOT_ITERABLE", - message_parameters={"type": type(res).__name__}, + message_parameters={ + "type": type(res).__name__, + "func": f.__name__, + }, ) # If the function returns a result, we map it to the internal representation and From 7a01ba65b7408bc3b907aa7b0b27279913caafe9 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 4 Sep 2023 09:36:49 +0900 Subject: [PATCH 26/35] [SPARK-45061][SS][CONNECT] Clean up Running python StreamingQueryLIstener processes when session expires ### What changes were proposed in this pull request? Clean up all running python StreamingQueryLIstener processes when session expires ### Why are the changes needed? Improvement ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Test will be added in SPARK-44462. Currently there is no way to test this because the session will never expire. 
This is because the started python listener process (on the server) will establish a connection with the server process with the same session id and ping it all the time. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42687 from WweiL/SPARK-44433-followup-listener-cleanup. Authored-by: Wei Liu Signed-off-by: Hyukjin Kwon --- .../connect/planner/SparkConnectPlanner.scala | 4 +++- .../sql/connect/service/SessionHolder.scala | 21 ++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 11300631491d9..579b378d09f65 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2900,7 +2900,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { SparkConnectService.streamingSessionManager.registerNewStreamingQuery(sessionHolder, query) // Register the runner with the query if Python foreachBatch is enabled. foreachBatchRunnerCleaner.foreach { cleaner => - sessionHolder.streamingRunnerCleanerCache.registerCleanerForQuery(query, cleaner) + sessionHolder.streamingForeachBatchRunnerCleanerCache.registerCleanerForQuery( + query, + cleaner) } executeHolder.eventsManager.postFinished() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 2034a97fce940..1cef02d7e3466 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -57,7 +57,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio new ConcurrentHashMap() // Handles Python process clean up for streaming queries. Initialized on first use in a query. - private[connect] lazy val streamingRunnerCleanerCache = + private[connect] lazy val streamingForeachBatchRunnerCleanerCache = new StreamingForeachBatchHelper.CleanerCache(this) /** Add ExecuteHolder to this session. Called only by SparkConnectExecutionManager. */ @@ -160,7 +160,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio eventManager.postClosed() // Clean up running queries SparkConnectService.streamingSessionManager.cleanupRunningQueries(this) - streamingRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers. + streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers. + removeAllListeners() // removes all listener and stop python listener processes if necessary. } /** @@ -237,11 +238,21 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio * Spark Connect PythonStreamingQueryListener. 
*/ private[connect] def removeCachedListener(id: String): Unit = { - listenerCache.get(id) match { - case pyListener: PythonStreamingQueryListener => pyListener.stopListenerProcess() + Option(listenerCache.remove(id)) match { + case Some(pyListener: PythonStreamingQueryListener) => pyListener.stopListenerProcess() case _ => // do nothing } - listenerCache.remove(id) + } + + /** + * Stop all streaming listener threads, and removes all python process if applicable. Only + * called when session is expired. + */ + private def removeAllListeners(): Unit = { + listenerCache.forEach((id, listener) => { + session.streams.removeListener(listener) + removeCachedListener(id) + }) } /** From 79cc7f838f01879d63e7d249591ce7079f11801a Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 4 Sep 2023 09:42:21 +0900 Subject: [PATCH 27/35] [SPARK-45032][CONNECT] Fix compilation warnings related to `Top-level wildcard is not allowed and will error under -Xsource:3` ### What changes were proposed in this pull request? Building with Scala 2.13 will result in the following compilation warnings: ``` [warn] /Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala:175:32: [deprecation | origin= | version=2.13.7] Top-level wildcard is not allowed and will error under -Xsource:3 [warn] def removeGrpcResponseSender[_](sender: ExecuteGrpcResponseSender[_]): Unit = synchronized { [warn] ^ ``` So this PR fixes it. ### Why are the changes needed? Fix compilation warnings related to `Top-level wildcard is not allowed and will error under -Xsource:3`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #42753 from LuciferYang/SPARK-45032. Authored-by: yangjie01 Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/sql/connect/service/ExecuteHolder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index bce0713339228..fb22893598505 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -172,7 +172,7 @@ private[connect] class ExecuteHolder( } } - def removeGrpcResponseSender[_](sender: ExecuteGrpcResponseSender[_]): Unit = synchronized { + def removeGrpcResponseSender(sender: ExecuteGrpcResponseSender[_]): Unit = synchronized { // if closed, we are shutting down and interrupting all senders already if (closedTime.isEmpty) { grpcResponseSenders -= From 3ca57ae7a9bc2053807e0d0f04c59104037137e4 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Mon, 4 Sep 2023 09:43:17 +0900 Subject: [PATCH 28/35] [SPARK-45038][PYTHON][DOCS] Refine docstring of `max` ### What changes were proposed in this pull request? This PR refines the docstring for function `max` by adding more examples. ### Why are the changes needed? To improve PySpark documentation. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? doctest ### Was this patch authored or co-authored using generative AI tooling? No Closes #42758 from allisonwang-db/spark-45038-refine-max.
Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 78 +++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fb02cb0cc98b4..47d928fe59a90 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1217,22 +1217,94 @@ def max(col: "ColumnOrName") -> Column: Parameters ---------- col : :class:`~pyspark.sql.Column` or str - target column to compute on. + The target column on which the maximum value is computed. Returns ------- :class:`~pyspark.sql.Column` - column for computed results. + A column that contains the maximum value computed. + + See Also + -------- + :meth:`pyspark.sql.functions.min` + :meth:`pyspark.sql.functions.avg` + :meth:`pyspark.sql.functions.sum` + + Notes + ----- + - Null values are ignored during the computation. + - NaN values are larger than any other numeric value. Examples -------- + Example 1: Compute the maximum value of a numeric column + + >>> import pyspark.sql.functions as sf >>> df = spark.range(10) - >>> df.select(max(col("id"))).show() + >>> df.select(sf.max(df.id)).show() +-------+ |max(id)| +-------+ | 9| +-------+ + + Example 2: Compute the maximum value of a string column + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([("A",), ("B",), ("C",)], ["value"]) + >>> df.select(sf.max(df.value)).show() + +----------+ + |max(value)| + +----------+ + | C| + +----------+ + + Example 3: Compute the maximum value of a column in a grouped DataFrame + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([("A", 1), ("A", 2), ("B", 3), ("B", 4)], ["key", "value"]) + >>> df.groupBy("key").agg(sf.max(df.value)).show() + +---+----------+ + |key|max(value)| + +---+----------+ + | A| 2| + | B| 4| + +---+----------+ + + Example 4: Compute the maximum value of multiple columns in a grouped DataFrame + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame( + ... [("A", 1, 2), ("A", 2, 3), ("B", 3, 4), ("B", 4, 5)], ["key", "value1", "value2"]) + >>> df.groupBy("key").agg(sf.max("value1"), sf.max("value2")).show() + +---+-----------+-----------+ + |key|max(value1)|max(value2)| + +---+-----------+-----------+ + | A| 2| 3| + | B| 4| 5| + +---+-----------+-----------+ + + Example 5: Compute the maximum value of a column with null values + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([(1,), (2,), (None,)], ["value"]) + >>> df.select(sf.max(df.value)).show() + +----------+ + |max(value)| + +----------+ + | 2| + +----------+ + + Example 6: Compute the maximum value of a column with "NaN" values + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([(1.1,), (float("nan"),), (3.3,)], ["value"]) + >>> df.select(sf.max(df.value)).show() + +----------+ + |max(value)| + +----------+ + | NaN| + +----------+ """ return _invoke_function_over_columns("max", col) From 6d1be3a389e407cb1422fa50e4b29270415d8fe5 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 4 Sep 2023 09:45:49 +0900 Subject: [PATCH 29/35] [SPARK-44667][INFRA][FOLLOWUP] Uninstall `deepspeed` libraries for non-ML jobs ### What changes were proposed in this pull request? Uninstall `deepspeed` libraries for non-ML jobs ### Why are the changes needed? 
after https://github.com/apache/spark/pull/42334, `deepspeed` was introduced as a dependency, which is not needed for non-ML jobs ### Does this PR introduce _any_ user-facing change? no, dev-only ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? NO Closes #42768 from zhengruifeng/uninstall_deepspeed. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .github/workflows/build_and_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index beb5a7772b7f8..f0bd65bcf415c 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -425,7 +425,7 @@ jobs: run: | if [[ "$MODULES_TO_TEST" != *"pyspark-ml"* ]] && [[ "$BRANCH" != "branch-3.5" ]]; then # uninstall libraries dedicated for ML testing - python3.9 -m pip uninstall -y torch torchvision torcheval torchtnt tensorboard mlflow + python3.9 -m pip uninstall -y torch torchvision torcheval torchtnt tensorboard mlflow deepspeed fi if [ -f ./dev/free_disk_space_container ]; then ./dev/free_disk_space_container From 09802865fd75947142611e0b73f3c6fa072640ee Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 4 Sep 2023 09:50:54 +0900 Subject: [PATCH 30/35] [SPARK-45053][PYTHON][MINOR] Log improvement in python version mismatch ### What changes were proposed in this pull request? Before: ``` pyspark.errors.exceptions.base.PySparkRuntimeError: [PYTHON_VERSION_MISMATCH] Python in worker has different version (3, 9) than that in driver 3.10, PySpark cannot run with different minor versions. ``` After: ``` pyspark.errors.exceptions.base.PySparkRuntimeError: [PYTHON_VERSION_MISMATCH] Python in worker has different version: 3.9 than that in driver: 3.10, PySpark cannot run with different minor versions. ``` ### Why are the changes needed? A little more easier to understand the error ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? No need ### Was this patch authored or co-authored using generative AI tooling? No Closes #42776 from WweiL/SPARK-45053-minor-log-improve. Authored-by: Wei Liu Signed-off-by: Hyukjin Kwon --- python/pyspark/errors/error_classes.py | 2 +- python/pyspark/worker_util.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 74f52c416e95b..c98e9feb610fa 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -644,7 +644,7 @@ }, "PYTHON_VERSION_MISMATCH" : { "message" : [ - "Python in worker has different version than that in driver , PySpark cannot run with different minor versions.", + "Python in worker has different version: than that in driver: , PySpark cannot run with different minor versions.", "Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set." ] }, diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py index 9f6d46c6211d5..722713b6f5422 100644 --- a/python/pyspark/worker_util.py +++ b/python/pyspark/worker_util.py @@ -70,11 +70,12 @@ def check_python_version(infile: IO) -> None: Check the Python version between the running process and the one used to serialize the command. 
""" version = utf8_deserializer.loads(infile) - if version != "%d.%d" % sys.version_info[:2]: + worker_version = "%d.%d" % sys.version_info[:2] + if version != worker_version: raise PySparkRuntimeError( error_class="PYTHON_VERSION_MISMATCH", message_parameters={ - "worker_version": str(sys.version_info[:2]), + "worker_version": worker_version, "driver_version": str(version), }, ) From f258af5a98b8f6fc9c338fb0fefb5aff751142a1 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Mon, 4 Sep 2023 09:53:40 +0900 Subject: [PATCH 31/35] [SPARK-45058][PYTHON][DOCS] Refine docstring of DataFrame.distinct ### What changes were proposed in this pull request? This PR refines the docstring of `DataFrame.distinct` by adding more examples. ### Why are the changes needed? To improve PySpark documentations. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? doctest ### Was this patch authored or co-authored using generative AI tooling? No Closes #42782 from allisonwang-db/spark-45058-refine-distinct. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/dataframe.py | 77 ++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 42d85b82e9e21..64592311a1326 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1934,15 +1934,90 @@ def distinct(self) -> "DataFrame": :class:`DataFrame` DataFrame with distinct records. + See Also + -------- + DataFrame.dropDuplicates + Examples -------- + Remove duplicate rows from a DataFrame + >>> df = spark.createDataFrame( ... [(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) + >>> df.distinct().show() + +---+-----+ + |age| name| + +---+-----+ + | 14| Tom| + | 23|Alice| + +---+-----+ - Return the number of distinct rows in the :class:`DataFrame` + Count the number of distinct rows in a DataFrame >>> df.distinct().count() 2 + + Get distinct rows from a DataFrame with multiple columns + + >>> df = spark.createDataFrame( + ... [(14, "Tom", "M"), (23, "Alice", "F"), (23, "Alice", "F"), (14, "Tom", "M")], + ... ["age", "name", "gender"]) + >>> df.distinct().show() + +---+-----+------+ + |age| name|gender| + +---+-----+------+ + | 14| Tom| M| + | 23|Alice| F| + +---+-----+------+ + + Get distinct values from a specific column in a DataFrame + + >>> df.select("name").distinct().show() + +-----+ + | name| + +-----+ + | Tom| + |Alice| + +-----+ + + Count the number of distinct values in a specific column + + >>> df.select("name").distinct().count() + 2 + + Get distinct values from multiple columns in DataFrame + + >>> df.select("name", "gender").distinct().show() + +-----+------+ + | name|gender| + +-----+------+ + | Tom| M| + |Alice| F| + +-----+------+ + + Get distinct rows from a DataFrame with null values + + >>> df = spark.createDataFrame( + ... [(14, "Tom", "M"), (23, "Alice", "F"), (23, "Alice", "F"), (14, "Tom", None)], + ... 
["age", "name", "gender"]) + >>> df.distinct().show() + +---+-----+------+ + |age| name|gender| + +---+-----+------+ + | 14| Tom| M| + | 23|Alice| F| + | 14| Tom| NULL| + +---+-----+------+ + + Get distinct non-null values from a DataFrame + + >>> df.distinct().filter(df.gender.isNotNull()).show() + +---+-----+------+ + |age| name|gender| + +---+-----+------+ + | 14| Tom| M| + | 23|Alice| F| + +---+-----+------+ """ return DataFrame(self._jdf.distinct(), self.sparkSession) From 5b609598503df603cbddd5e1adf8d2cb28a5f977 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 4 Sep 2023 09:50:51 +0800 Subject: [PATCH 32/35] [SPARK-45052][SQL][PYTHON][CONNECT] Make function aliases output column name consistent with SQL ### What changes were proposed in this pull request? for SQL function aliases (with `setAlias=true`) added in 3.5, replace expression construction with `call_function` to make the output column name consistent with SQL Note: 1. we still have multiple similar cases added in other versions, this PR mainly focus on 3.5 to make it easy to backport if needed; 2. not all function aliases have the issue (e.g. `current_schema`), to be conservative, I change them all; ### Why are the changes needed? before this PR ``` scala> val df = spark.range(0, 10) df: org.apache.spark.sql.Dataset[Long] = [id: bigint] scala> df.createOrReplaceTempView("t") scala> spark.sql("SELECT TRY_SUM(id), TRY_AVG(id) FROM t") res1: org.apache.spark.sql.DataFrame = [try_sum(id): bigint, try_avg(id): double] scala> df.select(try_sum(col("id")), try_avg(col("id"))) res2: org.apache.spark.sql.DataFrame = [sum(id): bigint, avg(id): double] scala> scala> spark.sql("SELECT sign(-1), signum(-1)") res3: org.apache.spark.sql.DataFrame = [sign(-1): double, SIGNUM(-1): double] scala> spark.range(1).select(sign(lit(-1)), signum(lit(-1))) res4: org.apache.spark.sql.DataFrame = [SIGNUM(-1): double, SIGNUM(-1): double] ``` after this PR ``` scala> spark.sql("SELECT TRY_SUM(id), TRY_AVG(id) FROM t") res9: org.apache.spark.sql.DataFrame = [try_sum(id): bigint, try_avg(id): double] scala> df.select(try_sum(col("id")), try_avg(col("id"))) res10: org.apache.spark.sql.DataFrame = [try_sum(id): bigint, try_avg(id): double] scala> scala> spark.sql("SELECT sign(-1), signum(-1)") res11: org.apache.spark.sql.DataFrame = [sign(-1): double, SIGNUM(-1): double] scala> spark.range(1).select(sign(lit(-1)), signum(lit(-1))) res12: org.apache.spark.sql.DataFrame = [sign(-1): double, SIGNUM(-1): double] ``` ### Does this PR introduce _any_ user-facing change? yes ### How was this patch tested? updated UT ### Was this patch authored or co-authored using generative AI tooling? no Closes #42775 from zhengruifeng/try_column_name. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../org/apache/spark/sql/functions.scala | 12 +- .../explain-results/describe.explain | 2 +- .../explain-results/function_ceiling.explain | 2 +- .../function_ceiling_scale.explain | 2 +- .../explain-results/function_printf.explain | 2 +- .../explain-results/function_sign.explain | 2 +- .../explain-results/function_std.explain | 2 +- .../query-tests/queries/function_ceiling.json | 2 +- .../queries/function_ceiling.proto.bin | Bin 173 -> 176 bytes .../queries/function_ceiling_scale.json | 2 +- .../queries/function_ceiling_scale.proto.bin | Bin 179 -> 182 bytes .../query-tests/queries/function_printf.json | 2 +- .../queries/function_printf.proto.bin | Bin 196 -> 189 bytes .../query-tests/queries/function_sign.json | 2 +- .../queries/function_sign.proto.bin | Bin 175 -> 173 bytes .../query-tests/queries/function_std.json | 2 +- .../queries/function_std.proto.bin | Bin 175 -> 172 bytes python/pyspark/sql/connect/functions.py | 26 +- python/pyspark/sql/functions.py | 716 ++++++++++++------ .../org/apache/spark/sql/functions.scala | 202 +++-- 20 files changed, 624 insertions(+), 354 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 8ea5f07c528f7..baafdd4e17222 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -987,7 +987,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def std(e: Column): Column = stddev(e) + def std(e: Column): Column = Column.fn("std", e) /** * Aggregate function: alias for `stddev_samp`. @@ -2337,7 +2337,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column, scale: Column): Column = ceil(e, scale) + def ceiling(e: Column, scale: Column): Column = Column.fn("ceiling", e, scale) /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2345,7 +2345,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column): Column = ceil(e) + def ceiling(e: Column): Column = Column.fn("ceiling", e) /** * Convert a number in a string column from one base to another. @@ -2800,7 +2800,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def power(l: Column, r: Column): Column = pow(l, r) + def power(l: Column, r: Column): Column = Column.fn("power", l, r) /** * Returns the positive value of dividend mod divisor. @@ -2937,7 +2937,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def sign(e: Column): Column = signum(e) + def sign(e: Column): Column = Column.fn("sign", e) /** * Computes the signum of the given value. 
@@ -4428,7 +4428,7 @@ object functions { * @since 3.5.0 */ def printf(format: Column, arguments: Column*): Column = - Column.fn("format_string", lit(format) +: arguments: _*) + Column.fn("printf", (format +: arguments): _*) /** * Decodes a `str` in 'application/x-www-form-urlencoded' format using a specific encoding diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/describe.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/describe.explain index f205f7ef7a140..b203f715c71a6 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/describe.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/describe.explain @@ -1,6 +1,6 @@ Project [summary#0, element_at(id#0, summary#0, None, false) AS id#0, element_at(b#0, summary#0, None, false) AS b#0] +- Project [id#0, b#0, summary#0] +- Generate explode([count,mean,stddev,min,max]), false, [summary#0] - +- Aggregate [map(cast(count as string), cast(count(id#0L) as string), cast(mean as string), cast(avg(id#0L) as string), cast(stddev as string), cast(stddev_samp(cast(id#0L as double)) as string), cast(min as string), cast(min(id#0L) as string), cast(max as string), cast(max(id#0L) as string)) AS id#0, map(cast(count as string), cast(count(b#0) as string), cast(mean as string), cast(avg(b#0) as string), cast(stddev as string), cast(stddev_samp(b#0) as string), cast(min as string), cast(min(b#0) as string), cast(max as string), cast(max(b#0) as string)) AS b#0] + +- Aggregate [map(cast(count as string), cast(count(id#0L) as string), cast(mean as string), cast(avg(id#0L) as string), cast(stddev as string), cast(stddev(cast(id#0L as double)) as string), cast(min as string), cast(min(id#0L) as string), cast(max as string), cast(max(id#0L) as string)) AS id#0, map(cast(count as string), cast(count(b#0) as string), cast(mean as string), cast(avg(b#0) as string), cast(stddev as string), cast(stddev(b#0) as string), cast(min as string), cast(min(b#0) as string), cast(max as string), cast(max(b#0) as string)) AS b#0] +- Project [id#0L, b#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling.explain index 9cf776a8dbaa7..217d7434b8020 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling.explain @@ -1,2 +1,2 @@ -Project [CEIL(b#0) AS CEIL(b)#0L] +Project [ceiling(b#0) AS ceiling(b)#0L] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling_scale.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling_scale.explain index cdf8d356e47dd..2c41c12278bad 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling_scale.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling_scale.explain @@ -1,2 +1,2 @@ -Project [ceil(cast(b#0 as decimal(30,15)), 2) AS ceil(b, 2)#0] +Project [ceiling(cast(b#0 as decimal(30,15)), 2) AS ceiling(b, 2)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_printf.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_printf.explain index 10409df007070..8d55d77340002 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_printf.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_printf.explain @@ -1,2 +1,2 @@ -Project [format_string(g#0, a#0, g#0) AS format_string(g, a, g)#0] +Project [printf(g#0, a#0, g#0) AS printf(g, a, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sign.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sign.explain index 807fa3300836c..5d41e16b6cef4 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sign.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sign.explain @@ -1,2 +1,2 @@ -Project [SIGNUM(b#0) AS SIGNUM(b)#0] +Project [sign(b#0) AS sign(b)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_std.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_std.explain index 106191e5a32ec..cf5b86ae3a571 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_std.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_std.explain @@ -1,2 +1,2 @@ -Aggregate [stddev(cast(a#0 as double)) AS stddev(a)#0] +Aggregate [std(cast(a#0 as double)) AS std(a)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.json b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.json index 5a9961ab47f55..99726305e8524 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "ceil", + "functionName": "ceiling", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "b" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.proto.bin index 3761deb1663a2fba0fd8e7e69eafd1a4d0f83415..cc91ac246a57c37ca5553326f5c8e57b4af2cba5 100644 GIT binary patch delta 34 pcmZ3>xPg(4i%Eb{YUM<>>B2%%0$l9LshK&MdFeu|Ld;x@NdScs2S@+_ delta 31 mcmdnMxR#NPi%Eb{YS~1#=>q&xyj(2FshK%KtU}COj7b1!r3L!{ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.json b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.json index bda5e85924c30..c0b0742b12157 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "ceil", + "functionName": "ceiling", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "b" diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.proto.bin index 8db402ac167e0b1a1908e94a935376d3124a3876..30efc42b9d2bcabddfdc0c6b80c08971c759feff 100644 GIT binary patch delta 40 vcmdnYxQ&sGi%Eb{YW+mE>5>vsVqEOWshK&MdFeu|Ld;x@NkS}KOa@E0+W%!dxuLshK%KtU}COj7dT)TucT`0FiG7a{vGU diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_printf.json b/connector/connect/common/src/test/resources/query-tests/queries/function_printf.json index dc7ca880c4b09..73ca595e8650b 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_printf.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_printf.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "format_string", + "functionName": "printf", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "g" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_printf.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_printf.proto.bin index 7ebdda6cac10d8646e13dd3853e1d444f41b1586..3fb3862f44d91262510d0c07615e0256348e4a8e 100644 GIT binary patch delta 29 lcmX@YxR;TQi%Eb{YTHD%>D=;CvRrHhMVWaeX%p9r0sv~&2mSy6 delta 54 zcmdnXc!ZISi%Eb{YVSn0>FR1yDqOs2`9-;jCGo{2MVWc&Laai}T#QgE5y}GqaM%l> diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_sign.json b/connector/connect/common/src/test/resources/query-tests/queries/function_sign.json index bcf6ad7eb174d..34451969078b0 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_sign.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_sign.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "signum", + "functionName": "sign", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "b" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_sign.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_sign.proto.bin index af52abfb7f25b34a0a6afd5283a80270576c357d..ff866c97303edefe0216166590c81a175aa505f1 100644 GIT binary patch delta 31 mcmZ3_xR#NPi%Eb{YS~1#=>q&xyj(2Bndx~#tU}COj7b1#Tm}vR delta 33 ocmZ3>xSo-Xi%Eb{YQ;ph=|X~1{9J6sndy0@xk9W$%v_8~0DOA}`Tzg` diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_std.json b/connector/connect/common/src/test/resources/query-tests/queries/function_std.json index 1403817886ca0..cbdb4ea9e5e83 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_std.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_std.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "stddev", + "functionName": "std", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "a" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_std.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_std.proto.bin index 8d214eea8e74e4e8a8bc6674175c755ce456c838..7e34b0427c23bb8aeb5810cbf532787a68aa8e57 100644 GIT binary patch delta 30 lcmZ3_xQ3C9i%Eb{YUxC_>HK_BJY3AhB`HFzLd;x@i2z{81)%@{ delta 33 ocmZ3(xSo-Xi%Eb{YQ;ph=|X~1{9J6sB`GPXWkRe%%v_9#0DKSz?*IS* diff --git a/python/pyspark/sql/connect/functions.py 
b/python/pyspark/sql/connect/functions.py index ae28eb9346bfd..d5d2cd1c5e93f 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -559,7 +559,11 @@ def ceil(col: "ColumnOrName") -> Column: ceil.__doc__ = pysparkfuncs.ceil.__doc__ -ceiling = ceil +def ceiling(col: "ColumnOrName") -> Column: + return _invoke_function_over_columns("ceiling", col) + + +ceiling.__doc__ = pysparkfuncs.ceiling.__doc__ def conv(col: "ColumnOrName", fromBase: int, toBase: int) -> Column: @@ -827,7 +831,11 @@ def signum(col: "ColumnOrName") -> Column: signum.__doc__ = pysparkfuncs.signum.__doc__ -sign = signum +def sign(col: "ColumnOrName") -> Column: + return _invoke_function_over_columns("sign", col) + + +sign.__doc__ = pysparkfuncs.sign.__doc__ def sin(col: "ColumnOrName") -> Column: @@ -1203,13 +1211,17 @@ def skewness(col: "ColumnOrName") -> Column: def stddev(col: "ColumnOrName") -> Column: - return stddev_samp(col) + return _invoke_function_over_columns("stddev", col) stddev.__doc__ = pysparkfuncs.stddev.__doc__ -std = stddev +def std(col: "ColumnOrName") -> Column: + return _invoke_function_over_columns("std", col) + + +std.__doc__ = pysparkfuncs.std.__doc__ def stddev_samp(col: "ColumnOrName") -> Column: @@ -1333,7 +1345,7 @@ def variance(col: "ColumnOrName") -> Column: def every(col: "ColumnOrName") -> Column: - return _invoke_function_over_columns("bool_and", col) + return _invoke_function_over_columns("every", col) every.__doc__ = pysparkfuncs.every.__doc__ @@ -1347,7 +1359,7 @@ def bool_and(col: "ColumnOrName") -> Column: def some(col: "ColumnOrName") -> Column: - return _invoke_function_over_columns("bool_or", col) + return _invoke_function_over_columns("some", col) some.__doc__ = pysparkfuncs.some.__doc__ @@ -2561,7 +2573,7 @@ def parse_url( def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: - return _invoke_function("printf", lit(format), *[_to_col(c) for c in cols]) + return _invoke_function("printf", _to_col(format), *[_to_col(c) for c in cols]) printf.__doc__ = pysparkfuncs.printf.__doc__ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 47d928fe59a90..52707217bdafd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -932,12 +932,12 @@ def try_avg(col: "ColumnOrName") -> Column: >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [(1982, 15), (1990, 2)], ["birth", "age"] - ... ).select(sf.try_avg("age").alias("age_avg")).show() - +-------+ - |age_avg| - +-------+ - | 8.5| - +-------+ + ... ).select(sf.try_avg("age")).show() + +------------+ + |try_avg(age)| + +------------+ + | 8.5| + +------------+ """ return _invoke_function_over_columns("try_avg", col) @@ -1122,12 +1122,12 @@ def try_sum(col: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(10).select(sf.try_sum("id").alias("sum")).show() - +---+ - |sum| - +---+ - | 45| - +---+ + >>> spark.range(10).select(sf.try_sum("id")).show() + +-----------+ + |try_sum(id)| + +-----------+ + | 45| + +-----------+ """ return _invoke_function_over_columns("try_sum", col) @@ -1948,7 +1948,37 @@ def ceil(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("ceil", col) -ceiling = ceil +@try_remote_functions +def ceiling(col: "ColumnOrName") -> Column: + """ + Computes the ceiling of the given value. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. 
+ + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.ceil(sf.lit(-0.1))).show() + +----------+ + |CEIL(-0.1)| + +----------+ + | 0| + +----------+ + """ + return _invoke_function_over_columns("ceiling", col) @try_remote_functions @@ -2300,14 +2330,15 @@ def negative(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.range(3).select(negative("id").alias("n")).show() - +---+ - | n| - +---+ - | 0| - | -1| - | -2| - +---+ + >>> import pyspark.sql.functions as sf + >>> spark.range(3).select(sf.negative("id")).show() + +------------+ + |negative(id)| + +------------+ + | 0| + | -1| + | -2| + +------------+ """ return _invoke_function_over_columns("negative", col) @@ -2457,25 +2488,54 @@ def signum(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(signum(lit(-5))).show() - +----------+ - |SIGNUM(-5)| - +----------+ - | -1.0| - +----------+ - - >>> df.select(signum(lit(6))).show() - +---------+ - |SIGNUM(6)| - +---------+ - | 1.0| - +---------+ + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select( + ... sf.signum(sf.lit(-5)), + ... sf.signum(sf.lit(6)) + ... ).show() + +----------+---------+ + |SIGNUM(-5)|SIGNUM(6)| + +----------+---------+ + | -1.0| 1.0| + +----------+---------+ """ return _invoke_function_over_columns("signum", col) -sign = signum +@try_remote_functions +def sign(col: "ColumnOrName") -> Column: + """ + Computes the signum of the given value. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select( + ... sf.sign(sf.lit(-5)), + ... sf.sign(sf.lit(6)) + ... ).show() + +--------+-------+ + |sign(-5)|sign(6)| + +--------+-------+ + | -1.0| 1.0| + +--------+-------+ + """ + return _invoke_function_over_columns("sign", col) @try_remote_functions @@ -2778,15 +2838,17 @@ def getbit(col: "ColumnOrName", pos: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) - >>> df.select(getbit("c", lit(1)).alias("d")).show() - +---+ - | d| - +---+ - | 0| - | 0| - | 1| - +---+ + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[1], [1], [2]], ["c"] + ... ).select(sf.getbit("c", sf.lit(1))).show() + +------------+ + |getbit(c, 1)| + +------------+ + | 0| + | 0| + | 1| + +------------+ """ return _invoke_function_over_columns("getbit", col, pos) @@ -2983,14 +3045,45 @@ def stddev(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(6) - >>> df.select(stddev(df.id)).first() - Row(stddev_samp(id)=1.87082...) + >>> import pyspark.sql.functions as sf + >>> spark.range(6).select(sf.stddev("id")).show() + +------------------+ + | stddev(id)| + +------------------+ + |1.8708286933869...| + +------------------+ """ return _invoke_function_over_columns("stddev", col) -std = stddev +@try_remote_functions +def std(col: "ColumnOrName") -> Column: + """ + Aggregate function: alias for stddev_samp. + + .. 
versionadded:: 3.5.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + standard deviation of given column. + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(6).select(sf.std("id")).show() + +------------------+ + | std(id)| + +------------------+ + |1.8708286933869...| + +------------------+ + """ + return _invoke_function_over_columns("std", col) @try_remote_functions @@ -3016,9 +3109,13 @@ def stddev_samp(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(6) - >>> df.select(stddev_samp(df.id)).first() - Row(stddev_samp(id)=1.87082...) + >>> import pyspark.sql.functions as sf + >>> spark.range(6).select(sf.stddev_samp("id")).show() + +------------------+ + | stddev_samp(id)| + +------------------+ + |1.8708286933869...| + +------------------+ """ return _invoke_function_over_columns("stddev_samp", col) @@ -3046,9 +3143,13 @@ def stddev_pop(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(6) - >>> df.select(stddev_pop(df.id)).first() - Row(stddev_pop(id)=1.70782...) + >>> import pyspark.sql.functions as sf + >>> spark.range(6).select(sf.stddev_pop("id")).show() + +-----------------+ + | stddev_pop(id)| + +-----------------+ + |1.707825127659...| + +-----------------+ """ return _invoke_function_over_columns("stddev_pop", col) @@ -3448,27 +3549,35 @@ def every(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[True], [True], [True]], ["flag"]) - >>> df.select(every("flag")).show() - +--------------+ - |bool_and(flag)| - +--------------+ - | true| - +--------------+ - >>> df = spark.createDataFrame([[True], [False], [True]], ["flag"]) - >>> df.select(every("flag")).show() - +--------------+ - |bool_and(flag)| - +--------------+ - | false| - +--------------+ - >>> df = spark.createDataFrame([[False], [False], [False]], ["flag"]) - >>> df.select(every("flag")).show() - +--------------+ - |bool_and(flag)| - +--------------+ - | false| - +--------------+ + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[True], [True], [True]], ["flag"] + ... ).select(sf.every("flag")).show() + +-----------+ + |every(flag)| + +-----------+ + | true| + +-----------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[True], [False], [True]], ["flag"] + ... ).select(sf.every("flag")).show() + +-----------+ + |every(flag)| + +-----------+ + | false| + +-----------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[False], [False], [False]], ["flag"] + ... 
).select(sf.every("flag")).show() + +-----------+ + |every(flag)| + +-----------+ + | false| + +-----------+ """ return _invoke_function_over_columns("every", col) @@ -3536,27 +3645,35 @@ def some(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[True], [True], [True]], ["flag"]) - >>> df.select(some("flag")).show() - +-------------+ - |bool_or(flag)| - +-------------+ - | true| - +-------------+ - >>> df = spark.createDataFrame([[True], [False], [True]], ["flag"]) - >>> df.select(some("flag")).show() - +-------------+ - |bool_or(flag)| - +-------------+ - | true| - +-------------+ - >>> df = spark.createDataFrame([[False], [False], [False]], ["flag"]) - >>> df.select(some("flag")).show() - +-------------+ - |bool_or(flag)| - +-------------+ - | false| - +-------------+ + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[True], [True], [True]], ["flag"] + ... ).select(sf.some("flag")).show() + +----------+ + |some(flag)| + +----------+ + | true| + +----------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[True], [False], [True]], ["flag"] + ... ).select(sf.some("flag")).show() + +----------+ + |some(flag)| + +----------+ + | true| + +----------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[False], [False], [False]], ["flag"] + ... ).select(sf.some("flag")).show() + +----------+ + |some(flag)| + +----------+ + | false| + +----------+ """ return _invoke_function_over_columns("some", col) @@ -5225,22 +5342,23 @@ def approx_percentile( Examples -------- - >>> key = (col("id") % 3).alias("key") - >>> value = (randn(42) + key * 10).alias("value") + >>> import pyspark.sql.functions as sf + >>> key = (sf.col("id") % 3).alias("key") + >>> value = (sf.randn(42) + key * 10).alias("value") >>> df = spark.range(0, 1000, 1, 1).select(key, value) >>> df.select( - ... approx_percentile("value", [0.25, 0.5, 0.75], 1000000).alias("quantiles") + ... sf.approx_percentile("value", [0.25, 0.5, 0.75], 1000000) ... ).printSchema() root - |-- quantiles: array (nullable = true) + |-- approx_percentile(value, array(0.25, 0.5, 0.75), 1000000): array (nullable = true) | |-- element: double (containsNull = false) >>> df.groupBy("key").agg( - ... approx_percentile("value", 0.5, lit(1000000)).alias("median") + ... sf.approx_percentile("value", 0.5, sf.lit(1000000)) ... ).printSchema() root |-- key: long (nullable = true) - |-- median: double (nullable = true) + |-- approx_percentile(value, 0.5, 1000000): double (nullable = true) """ sc = get_active_spark_context() @@ -6282,15 +6400,25 @@ def first_value(col: "ColumnOrName", ignoreNulls: Optional[Union[bool, Column]] Examples -------- - >>> df = spark.createDataFrame([(None, 1), - ... ("a", 2), - ... ("a", 3), - ... ("b", 8), - ... ("b", 2)], ["c1", "c2"]) - >>> df.select(first_value('c1').alias('a'), first_value('c2').alias('b')).collect() - [Row(a=None, b=1)] - >>> df.select(first_value('c1', True).alias('a'), first_value('c2', True).alias('b')).collect() - [Row(a='a', b=1)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["a", "b"] + ... ).select(sf.first_value('a'), sf.first_value('b')).show() + +--------------+--------------+ + |first_value(a)|first_value(b)| + +--------------+--------------+ + | NULL| 1| + +--------------+--------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... 
[(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["a", "b"] + ... ).select(sf.first_value('a', True), sf.first_value('b', True)).show() + +--------------+--------------+ + |first_value(a)|first_value(b)| + +--------------+--------------+ + | a| 1| + +--------------+--------------+ """ if ignoreNulls is None: return _invoke_function_over_columns("first_value", col) @@ -6320,15 +6448,25 @@ def last_value(col: "ColumnOrName", ignoreNulls: Optional[Union[bool, Column]] = Examples -------- - >>> df = spark.createDataFrame([("a", 1), - ... ("a", 2), - ... ("a", 3), - ... ("b", 8), - ... (None, 2)], ["c1", "c2"]) - >>> df.select(last_value('c1').alias('a'), last_value('c2').alias('b')).collect() - [Row(a=None, b=2)] - >>> df.select(last_value('c1', True).alias('a'), last_value('c2', True).alias('b')).collect() - [Row(a='b', b=2)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("a", 1), ("a", 2), ("a", 3), ("b", 8), (None, 2)], ["a", "b"] + ... ).select(sf.last_value('a'), sf.last_value('b')).show() + +-------------+-------------+ + |last_value(a)|last_value(b)| + +-------------+-------------+ + | NULL| 2| + +-------------+-------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("a", 1), ("a", 2), ("a", 3), ("b", 8), (None, 2)], ["a", "b"] + ... ).select(sf.last_value('a', True), sf.last_value('b', True)).show() + +-------------+-------------+ + |last_value(a)|last_value(b)| + +-------------+-------------+ + | b| 2| + +-------------+-------------+ """ if ignoreNulls is None: return _invoke_function_over_columns("last_value", col) @@ -6490,8 +6628,8 @@ def curdate() -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(curdate()).show() # doctest: +SKIP + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP +--------------+ |current_date()| +--------------+ @@ -7237,13 +7375,35 @@ def dateadd(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'add']) - >>> df.select(dateadd(df.dt, 1).alias('next_date')).collect() - [Row(next_date=datetime.date(2015, 4, 9))] - >>> df.select(dateadd(df.dt, df.add.cast('integer')).alias('next_date')).collect() - [Row(next_date=datetime.date(2015, 4, 10))] - >>> df.select(dateadd('dt', -1).alias('prev_date')).collect() - [Row(prev_date=datetime.date(2015, 4, 7))] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [('2015-04-08', 2,)], ['dt', 'add'] + ... ).select(sf.dateadd("dt", 1)).show() + +---------------+ + |date_add(dt, 1)| + +---------------+ + | 2015-04-09| + +---------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [('2015-04-08', 2,)], ['dt', 'add'] + ... ).select(sf.dateadd("dt", sf.lit(2))).show() + +---------------+ + |date_add(dt, 2)| + +---------------+ + | 2015-04-10| + +---------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [('2015-04-08', 2,)], ['dt', 'add'] + ... 
).select(sf.dateadd("dt", -1)).show() + +----------------+ + |date_add(dt, -1)| + +----------------+ + | 2015-04-07| + +----------------+ """ days = lit(days) if isinstance(days, int) else days return _invoke_function_over_columns("dateadd", start, days) @@ -7710,9 +7870,15 @@ def xpath_number(xml: "ColumnOrName", path: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('12',)], ['x']) - >>> df.select(xpath_number(df.x, lit('sum(a/b)')).alias('r')).collect() - [Row(r=3.0)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [('12',)], ['x'] + ... ).select(sf.xpath_number('x', sf.lit('sum(a/b)'))).show() + +-------------------------+ + |xpath_number(x, sum(a/b))| + +-------------------------+ + | 3.0| + +-------------------------+ """ return _invoke_function_over_columns("xpath_number", xml, path) @@ -8599,7 +8765,8 @@ def current_schema() -> Column: Examples -------- - >>> spark.range(1).select(current_schema()).show() + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.current_schema()).show() +------------------+ |current_database()| +------------------+ @@ -8635,7 +8802,8 @@ def user() -> Column: Examples -------- - >>> spark.range(1).select(user()).show() # doctest: +SKIP + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.user()).show() # doctest: +SKIP +--------------+ |current_user()| +--------------+ @@ -9901,13 +10069,35 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - >>> df.select(regexp('str', lit(r'(\d+)')).alias('d')).collect() - [Row(d=True)] - >>> df.select(regexp('str', lit(r'\d{2}b')).alias('d')).collect() - [Row(d=False)] - >>> df.select(regexp("str", col("regexp")).alias('d')).collect() - [Row(d=True)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp('str', sf.lit(r'(\d+)'))).show() + +------------------+ + |REGEXP(str, (\d+))| + +------------------+ + | true| + +------------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp('str', sf.lit(r'\d{2}b'))).show() + +-------------------+ + |REGEXP(str, \d{2}b)| + +-------------------+ + | false| + +-------------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp('str', sf.col("regexp"))).show() + +-------------------+ + |REGEXP(str, regexp)| + +-------------------+ + | true| + +-------------------+ """ return _invoke_function_over_columns("regexp", str, regexp) @@ -9932,13 +10122,35 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - >>> df.select(regexp_like('str', lit(r'(\d+)')).alias('d')).collect() - [Row(d=True)] - >>> df.select(regexp_like('str', lit(r'\d{2}b')).alias('d')).collect() - [Row(d=False)] - >>> df.select(regexp_like("str", col("regexp")).alias('d')).collect() - [Row(d=True)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... 
).select(sf.regexp_like('str', sf.lit(r'(\d+)'))).show() + +-----------------------+ + |REGEXP_LIKE(str, (\d+))| + +-----------------------+ + | true| + +-----------------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp_like('str', sf.lit(r'\d{2}b'))).show() + +------------------------+ + |REGEXP_LIKE(str, \d{2}b)| + +------------------------+ + | false| + +------------------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp_like('str', sf.col("regexp"))).show() + +------------------------+ + |REGEXP_LIKE(str, regexp)| + +------------------------+ + | true| + +------------------------+ """ return _invoke_function_over_columns("regexp_like", str, regexp) @@ -10679,12 +10891,25 @@ def substr( Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", 5, 1,)], ["a", "b", "c"]) - >>> df.select(substr(df.a, df.b, df.c).alias('r')).collect() - [Row(r='k')] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... ).select(sf.substr("a", "b", "c")).show() + +---------------+ + |substr(a, b, c)| + +---------------+ + | k| + +---------------+ - >>> df.select(substr(df.a, df.b).alias('r')).collect() - [Row(r='k SQL')] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... ).select(sf.substr("a", "b")).show() + +------------------------+ + |substr(a, b, 2147483647)| + +------------------------+ + | k SQL| + +------------------------+ """ if len is not None: return _invoke_function_over_columns("substr", str, pos, len) @@ -10744,9 +10969,15 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("aa%d%s", 123, "cc",)], ["a", "b", "c"]) - >>> df.select(printf(df.a, df.b, df.c).alias('r')).collect() - [Row(r='aa123cc')] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("aa%d%s", 123, "cc",)], ["a", "b", "c"] + ... ).select(sf.printf("a", "b", "c")).show() + +---------------+ + |printf(a, b, c)| + +---------------+ + | aa123cc| + +---------------+ """ sc = get_active_spark_context() return _invoke_function("printf", _to_java_column(format), _to_seq(sc, cols, _to_java_column)) @@ -10817,12 +11048,24 @@ def position( Examples -------- - >>> df = spark.createDataFrame([("bar", "foobarbar", 5,)], ["a", "b", "c"]) - >>> df.select(position(df.a, df.b, df.c).alias('r')).collect() - [Row(r=7)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("bar", "foobarbar", 5,)], ["a", "b", "c"] + ... ).select(sf.position("a", "b", "c")).show() + +-----------------+ + |position(a, b, c)| + +-----------------+ + | 7| + +-----------------+ - >>> df.select(position(df.a, df.b).alias('r')).collect() - [Row(r=4)] + >>> spark.createDataFrame( + ... [("bar", "foobarbar", 5,)], ["a", "b", "c"] + ... 
).select(sf.position("a", "b")).show() + +-----------------+ + |position(a, b, 1)| + +-----------------+ + | 4| + +-----------------+ """ if start is not None: return _invoke_function_over_columns("position", substr, str, start) @@ -10921,9 +11164,13 @@ def char(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(65,)], ['a']) - >>> df.select(char(df.a).alias('r')).collect() - [Row(r='A')] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.char(sf.lit(65))).show() + +--------+ + |char(65)| + +--------+ + | A| + +--------+ """ return _invoke_function_over_columns("char", col) @@ -10974,9 +11221,13 @@ def char_length(str: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("SparkSQL",)], ['a']) - >>> df.select(char_length(df.a).alias('r')).collect() - [Row(r=8)] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.char_length(sf.lit("SparkSQL"))).show() + +---------------------+ + |char_length(SparkSQL)| + +---------------------+ + | 8| + +---------------------+ """ return _invoke_function_over_columns("char_length", str) @@ -10997,9 +11248,13 @@ def character_length(str: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("SparkSQL",)], ['a']) - >>> df.select(character_length(df.a).alias('r')).collect() - [Row(r=8)] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.character_length(sf.lit("SparkSQL"))).show() + +--------------------------+ + |character_length(SparkSQL)| + +--------------------------+ + | 8| + +--------------------------+ """ return _invoke_function_over_columns("character_length", str) @@ -11262,9 +11517,13 @@ def lcase(str: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark",)], ['a']) - >>> df.select(lcase(df.a).alias('r')).collect() - [Row(r='spark')] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.lcase(sf.lit("Spark"))).show() + +------------+ + |lcase(Spark)| + +------------+ + | spark| + +------------+ """ return _invoke_function_over_columns("lcase", str) @@ -11283,9 +11542,13 @@ def ucase(str: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark",)], ['a']) - >>> df.select(ucase(df.a).alias('r')).collect() - [Row(r='SPARK')] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.ucase(sf.lit("Spark"))).show() + +------------+ + |ucase(Spark)| + +------------+ + | SPARK| + +------------+ """ return _invoke_function_over_columns("ucase", str) @@ -12991,9 +13254,17 @@ def cardinality(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) - >>> df.select(cardinality(df.data).alias('r')).collect() - [Row(r=3), Row(r=1), Row(r=0)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [([1, 2, 3],),([1],),([],)], ['data'] + ... ).select(sf.cardinality("data")).show() + +-----------------+ + |cardinality(data)| + +-----------------+ + | 3| + | 1| + | 0| + +-----------------+ """ return _invoke_function_over_columns("cardinality", col) @@ -14850,26 +15121,27 @@ def make_timestamp_ltz( Examples -------- + >>> import pyspark.sql.functions as sf >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887, 'CET']], ... ["year", "month", "day", "hour", "min", "sec", "timezone"]) - >>> df.select(make_timestamp_ltz( - ... 
df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone).alias('r') + >>> df.select(sf.make_timestamp_ltz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone) ... ).show(truncate=False) - +-----------------------+ - |r | - +-----------------------+ - |2014-12-27 21:30:45.887| - +-----------------------+ - - >>> df.select(make_timestamp_ltz( - ... df.year, df.month, df.day, df.hour, df.min, df.sec).alias('r') + +--------------------------------------------------------------+ + |make_timestamp_ltz(year, month, day, hour, min, sec, timezone)| + +--------------------------------------------------------------+ + |2014-12-27 21:30:45.887 | + +--------------------------------------------------------------+ + + >>> df.select(sf.make_timestamp_ltz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) ... ).show(truncate=False) - +-----------------------+ - |r | - +-----------------------+ - |2014-12-28 06:30:45.887| - +-----------------------+ + +----------------------------------------------------+ + |make_timestamp_ltz(year, month, day, hour, min, sec)| + +----------------------------------------------------+ + |2014-12-28 06:30:45.887 | + +----------------------------------------------------+ >>> spark.conf.unset("spark.sql.session.timeZone") """ if timezone is not None: @@ -14918,17 +15190,18 @@ def make_timestamp_ntz( Examples -------- + >>> import pyspark.sql.functions as sf >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887]], ... ["year", "month", "day", "hour", "min", "sec"]) - >>> df.select(make_timestamp_ntz( - ... df.year, df.month, df.day, df.hour, df.min, df.sec).alias('r') + >>> df.select(sf.make_timestamp_ntz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) ... ).show(truncate=False) - +-----------------------+ - |r | - +-----------------------+ - |2014-12-28 06:30:45.887| - +-----------------------+ + +----------------------------------------------------+ + |make_timestamp_ntz(year, month, day, hour, min, sec)| + +----------------------------------------------------+ + |2014-12-28 06:30:45.887 | + +----------------------------------------------------+ >>> spark.conf.unset("spark.sql.session.timeZone") """ return _invoke_function_over_columns( @@ -15362,9 +15635,15 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: Examples -------- + >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) - >>> df.select(ifnull(df.e, lit(8)).alias('r')).collect() - [Row(r=8), Row(r=1)] + >>> df.select(sf.ifnull(df.e, sf.lit(8))).show() + +------------+ + |ifnull(e, 8)| + +------------+ + | 8| + | 1| + +------------+ """ return _invoke_function_over_columns("ifnull", col1, col2) @@ -15730,9 +16009,13 @@ def sha(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark",)], ["a"]) - >>> df.select(sha(df.a).alias('r')).collect() - [Row(r='85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c')] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.sha(sf.lit("Spark"))).show() + +--------------------+ + | sha(Spark)| + +--------------------+ + |85f5955f4b27a9a4c...| + +--------------------+ """ return _invoke_function_over_columns("sha", col) @@ -15810,12 +16093,19 @@ def java_method(*cols: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2",)], ["a"]) - >>> df.select( - ... 
java_method(lit("java.util.UUID"), lit("fromString"), df.a).alias('r') - ... ).collect() - [Row(r='a5cf6c42-0c85-418f-af6c-3e4e5b1328f2')] - + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select( + ... sf.java_method( + ... sf.lit("java.util.UUID"), + ... sf.lit("fromString"), + ... sf.lit("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2") + ... ) + ... ).show(truncate=False) + +-----------------------------------------------------------------------------+ + |java_method(java.util.UUID, fromString, a5cf6c42-0c85-418f-af6c-3e4e5b1328f2)| + +-----------------------------------------------------------------------------+ + |a5cf6c42-0c85-418f-af6c-3e4e5b1328f2 | + +-----------------------------------------------------------------------------+ """ return _invoke_function_over_seq_of_columns("java_method", cols) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9548f424ad407..dcfe10f9a4d80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -575,7 +575,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def first_value(e: Column): Column = first(e) + def first_value(e: Column): Column = call_function("first_value", e) /** * Aggregate function: returns the first value in a group. @@ -589,9 +589,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def first_value(e: Column, ignoreNulls: Column): Column = withAggregateFunction { - new First(e.expr, ignoreNulls.expr) - } + def first_value(e: Column, ignoreNulls: Column): Column = + call_function("first_value", e, ignoreNulls) /** * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated @@ -848,7 +847,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def last_value(e: Column): Column = last(e) + def last_value(e: Column): Column = call_function("last_value", e) /** * Aggregate function: returns the last value in a group. @@ -862,9 +861,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def last_value(e: Column, ignoreNulls: Column): Column = withAggregateFunction { - new Last(e.expr, ignoreNulls.expr) - } + def last_value(e: Column, ignoreNulls: Column): Column = + call_function("last_value", e, ignoreNulls) /** * Aggregate function: returns the most frequent value in a group. @@ -1017,9 +1015,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def approx_percentile(e: Column, percentage: Column, accuracy: Column): Column = { - percentile_approx(e, percentage, accuracy) - } + def approx_percentile(e: Column, percentage: Column, accuracy: Column): Column = + call_function("approx_percentile", e, percentage, accuracy) /** * Aggregate function: returns the product of all numerical elements in a group. @@ -1052,7 +1049,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def std(e: Column): Column = stddev(e) + def std(e: Column): Column = call_function("std", e) /** * Aggregate function: alias for `stddev_samp`. @@ -1060,7 +1057,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + def stddev(e: Column): Column = call_function("stddev", e) /** * Aggregate function: alias for `stddev_samp`. 
@@ -1330,7 +1327,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def every(e: Column): Column = withAggregateFunction { BoolAnd(e.expr) } + def every(e: Column): Column = call_function("every", e) /** * Aggregate function: returns true if all values of `e` are true. @@ -1346,7 +1343,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def some(e: Column): Column = withAggregateFunction { BoolOr(e.expr) } + def some(e: Column): Column = call_function("some", e) /** * Aggregate function: returns true if at least one value of `e` is true. @@ -1354,7 +1351,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def any(e: Column): Column = withAggregateFunction { BoolOr(e.expr) } + def any(e: Column): Column = call_function("any", e) /** * Aggregate function: returns true if at least one value of `e` is true. @@ -1944,9 +1941,8 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_avg(e: Column): Column = withAggregateFunction { - Average(e.expr, EvalMode.TRY) - } + def try_avg(e: Column): Column = + call_function("try_avg", e) /** * Returns `dividend``/``divisor`. It always performs floating point division. Its result is @@ -1984,9 +1980,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_sum(e: Column): Column = withAggregateFunction { - Sum(e.expr, EvalMode.TRY) - } + def try_sum(e: Column): Column = call_function("try_sum", e) /** * Creates a new struct column. @@ -2081,7 +2075,7 @@ object functions { * @group bitwise_funcs * @since 3.5.0 */ - def getbit(e: Column, pos: Column): Column = bit_get(e, pos) + def getbit(e: Column, pos: Column): Column = call_function("getbit", e, pos) /** * Parses the expression string into the column that it represents, similar to @@ -2385,7 +2379,8 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column, scale: Column): Column = ceil(e, scale) + def ceiling(e: Column, scale: Column): Column = + call_function("ceiling", e, scale) /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2393,7 +2388,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column): Column = ceil(e) + def ceiling(e: Column): Column = call_function("ceiling", e) /** * Convert a number in a string column from one base to another. @@ -2751,7 +2746,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def negative(e: Column): Column = withExpr { UnaryMinus(e.expr) } + def negative(e: Column): Column = call_function("negative", e) /** * Returns Pi. @@ -2979,7 +2974,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def sign(e: Column): Column = signum(e) + def sign(e: Column): Column = call_function("sign", e) /** * Computes the signum of the given value. @@ -3184,7 +3179,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def current_schema(): Column = withExpr { CurrentDatabase() } + def current_schema(): Column = call_function("current_schema") /** * Returns the user name of current execution context. @@ -3368,7 +3363,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def user(): Column = withExpr { CurrentUser() } + def user(): Column = call_function("user") /** * Returns the user name of current execution context. @@ -3646,9 +3641,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def sha(col: Column): Column = withExpr { - Sha1(col.expr) - } + def sha(col: Column): Column = call_function("sha", col) /** * Returns the length of the block being read, or -1 if not available. 
@@ -3686,9 +3679,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def java_method(cols: Column*): Column = withExpr { - CallMethodViaReflection(cols.map(_.expr)) - } + def java_method(cols: Column*): Column = + call_function("java_method", cols: _*) /** * Calls a method with reflection. @@ -3739,9 +3731,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def random(seed: Column): Column = withExpr { - Rand(seed.expr) - } + def random(seed: Column): Column = call_function("random", seed) /** * Returns a random value with independent and identically distributed (i.i.d.) uniformly @@ -3750,9 +3740,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def random(): Column = withExpr { - new Rand() - } + def random(): Column = call_function("random") /** * Returns the bucket number for the given input column. @@ -4058,7 +4046,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp(str: Column, regexp: Column): Column = rlike(str, regexp) + def regexp(str: Column, regexp: Column): Column = + call_function("regexp", str, regexp) /** * Returns true if `str` matches `regexp`, or false otherwise. @@ -4066,7 +4055,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_like(str: Column, regexp: Column): Column = rlike(str, regexp) + def regexp_like(str: Column, regexp: Column): Column = + call_function("regexp_like", str, regexp) /** * Returns a count of the number of times that the regular expression pattern `regexp` @@ -4532,9 +4522,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def substr(str: Column, pos: Column, len: Column): Column = withExpr { - Substring(str.expr, pos.expr, len.expr) - } + def substr(str: Column, pos: Column, len: Column): Column = + call_function("substr", str, pos, len) /** * Returns the substring of `str` that starts at `pos`, @@ -4543,9 +4532,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def substr(str: Column, pos: Column): Column = withExpr { - new Substring(str.expr, pos.expr) - } + def substr(str: Column, pos: Column): Column = + call_function("substr", str, pos) /** * Extracts a part from a URL. @@ -4573,9 +4561,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def printf(format: Column, arguments: Column*): Column = withExpr { - FormatString((lit(format) +: arguments).map(_.expr): _*) - } + def printf(format: Column, arguments: Column*): Column = + call_function("printf", (format +: arguments): _*) /** * Decodes a `str` in 'application/x-www-form-urlencoded' format @@ -4606,9 +4593,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def position(substr: Column, str: Column, start: Column): Column = withExpr { - StringLocate(substr.expr, str.expr, start.expr) - } + def position(substr: Column, str: Column, start: Column): Column = + call_function("position", substr, str, start) /** * Returns the position of the first occurrence of `substr` in `str` after position `1`. @@ -4617,9 +4603,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def position(substr: Column, str: Column): Column = withExpr { - new StringLocate(substr.expr, str.expr) - } + def position(substr: Column, str: Column): Column = + call_function("position", substr, str) /** * Returns a boolean. The value is True if str ends with suffix. 
@@ -4648,9 +4633,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def char(n: Column): Column = withExpr { - Chr(n.expr) - } + def char(n: Column): Column = call_function("char", n) /** * Removes the leading and trailing space characters from `str`. @@ -4714,9 +4697,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def char_length(str: Column): Column = withExpr { - Length(str.expr) - } + def char_length(str: Column): Column = call_function("char_length", str) /** * Returns the character length of string data or number of bytes of binary data. @@ -4726,9 +4707,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def character_length(str: Column): Column = withExpr { - Length(str.expr) - } + def character_length(str: Column): Column = call_function("character_length", str) /** * Returns the ASCII character having the binary equivalent to `n`. @@ -4837,9 +4816,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def lcase(str: Column): Column = withExpr { - Lower(str.expr) - } + def lcase(str: Column): Column = call_function("lcase", str) /** * Returns `str` with all characters changed to uppercase. @@ -4847,9 +4824,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def ucase(str: Column): Column = withExpr { - Upper(str.expr) - } + def ucase(str: Column): Column = call_function("ucase", str) /** * Returns the leftmost `len`(`len` can be string type) characters from the string `str`, @@ -4911,7 +4886,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def curdate(): Column = withExpr { CurrentDate() } + def curdate(): Column = call_function("curdate") /** * Returns the current date at the start of query evaluation as a date column. @@ -5013,7 +4988,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def dateadd(start: Column, days: Column): Column = date_add(start, days) + def dateadd(start: Column, days: Column): Column = + call_function("dateadd", start, days) /** * Returns the date that is `days` days before `start` @@ -5078,7 +5054,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_diff(end: Column, start: Column): Column = datediff(end, start) + def date_diff(end: Column, start: Column): Column = + call_function("date_diff", end, start) /** * Create date from the number of `days` since 1970-01-01. @@ -5135,7 +5112,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def day(e: Column): Column = dayofmonth(e) + def day(e: Column): Column = call_function("day", e) /** * Extracts the day of the year as an integer from a given date/timestamp/string. @@ -5174,7 +5151,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_part(field: Column, source: Column): Column = call_function("date_part", field, source) + def date_part(field: Column, source: Column): Column = + call_function("date_part", field, source) /** * Extracts a part of the date/timestamp or interval source. @@ -5186,7 +5164,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def datepart(field: Column, source: Column): Column = call_function("datepart", field, source) + def datepart(field: Column, source: Column): Column = + call_function("datepart", field, source) /** * Returns the last day of the month which the given date belongs to. 
@@ -5439,9 +5418,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def try_to_timestamp(s: Column, format: Column): Column = withExpr { - new ParseToTimestamp(s.expr, format.expr) - } + def try_to_timestamp(s: Column, format: Column): Column = + call_function("try_to_timestamp", s, format) /** * Parses the `s` to a timestamp. The function always returns null on an invalid @@ -5451,9 +5429,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def try_to_timestamp(s: Column): Column = withExpr { - new ParseToTimestamp(s.expr) - } + def try_to_timestamp(s: Column): Column = + call_function("try_to_timestamp", s) /** * Converts the column into `DateType` by casting rules to `DateType`. @@ -5890,9 +5867,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ltz(timestamp: Column, format: Column): Column = withExpr { - ParseToTimestamp(timestamp.expr, Some(format.expr), TimestampType) - } + def to_timestamp_ltz(timestamp: Column, format: Column): Column = + call_function("to_timestamp_ltz", timestamp, format) /** * Parses the `timestamp` expression with the default format to a timestamp without time zone. @@ -5901,9 +5877,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ltz(timestamp: Column): Column = withExpr { - ParseToTimestamp(timestamp.expr, None, TimestampType) - } + def to_timestamp_ltz(timestamp: Column): Column = + call_function("to_timestamp_ltz", timestamp) /** * Parses the `timestamp_str` expression with the `format` expression @@ -5912,9 +5887,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ntz(timestamp: Column, format: Column): Column = withExpr { - ParseToTimestamp(timestamp.expr, Some(format.expr), TimestampNTZType) - } + def to_timestamp_ntz(timestamp: Column, format: Column): Column = + call_function("to_timestamp_ntz", timestamp, format) /** * Parses the `timestamp` expression with the default format to a timestamp without time zone. @@ -5923,9 +5897,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ntz(timestamp: Column): Column = withExpr { - ParseToTimestamp(timestamp.expr, None, TimestampNTZType) - } + def to_timestamp_ntz(timestamp: Column): Column = + call_function("to_timestamp_ntz", timestamp) /** * Returns the UNIX timestamp of the given time. @@ -7030,7 +7003,7 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def cardinality(e: Column): Column = size(e) + def cardinality(e: Column): Column = call_function("cardinality", e) /** * Sorts the input array for the given column in ascending order, @@ -7088,7 +7061,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def array_agg(e: Column): Column = collect_list(e) + def array_agg(e: Column): Column = call_function("array_agg", e) /** * Returns a random permutation of the given array. 
@@ -7487,9 +7460,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_number(x: Column, p: Column): Column = withExpr { - XPathDouble(x.expr, p.expr) - } + def xpath_number(x: Column, p: Column): Column = + call_function("xpath_number", x, p) /** * Returns a float value, the value zero if no match is found, @@ -7788,10 +7760,9 @@ object functions { hours: Column, mins: Column, secs: Column, - timezone: Column): Column = withExpr { - MakeTimestamp(years.expr, months.expr, days.expr, hours.expr, - mins.expr, secs.expr, Some(timezone.expr), dataType = TimestampType) - } + timezone: Column): Column = + call_function("make_timestamp_ltz", + years, months, days, hours, mins, secs, timezone) /** * Create the current timestamp with local time zone from years, months, days, hours, mins and @@ -7807,10 +7778,9 @@ object functions { days: Column, hours: Column, mins: Column, - secs: Column): Column = withExpr { - MakeTimestamp(years.expr, months.expr, days.expr, hours.expr, - mins.expr, secs.expr, dataType = TimestampType) - } + secs: Column): Column = + call_function("make_timestamp_ltz", + years, months, days, hours, mins, secs) /** * Create local date-time from years, months, days, hours, mins, secs fields. If the @@ -7826,10 +7796,9 @@ object functions { days: Column, hours: Column, mins: Column, - secs: Column): Column = withExpr { - MakeTimestamp(years.expr, months.expr, days.expr, hours.expr, - mins.expr, secs.expr, dataType = TimestampNTZType) - } + secs: Column): Column = + call_function("make_timestamp_ntz", + years, months, days, hours, mins, secs) /** * Make year-month interval from years, months. @@ -7896,9 +7865,8 @@ object functions { * @group predicates_funcs * @since 3.5.0 */ - def ifnull(col1: Column, col2: Column): Column = withExpr { - new Nvl(col1.expr, col2.expr) - } + def ifnull(col1: Column, col2: Column): Column = + call_function("ifnull", col1, col2) /** * Returns true if `col` is not null, or false otherwise. From cf0a5cb472efebb4350e48bd82a4f834e8607333 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 4 Sep 2023 11:41:48 +0900 Subject: [PATCH 33/35] [SPARK-45045][SS] Revert back the behavior of idle progress for StreamingQuery API from SPARK-43183 ### What changes were proposed in this pull request? This PR proposes to revert back the behavior of idle progress for StreamingQuery API from [SPARK-43183](https://issues.apache.org/jira/browse/SPARK-43183), to avoid breakage of tests from 3rd party data sources. ### Why are the changes needed? We indicated that the behavioral change from SPARK-43183 broke many tests in 3rd party data sources. (Short summary of SPARK-43183: we changed the behavior of idle progress to only provide idle event callback, instead of making progress update callback as well as adding progress for StreamingQuery API to provide as recent progresses/last progress.) The main rationale of SPARK-43183 was to avoid making progress update callback for idle event, which had been confused users. That is more about streaming query listener, and not necessarily had to change the behavior of StreamingQuery API as well. ### Does this PR introduce _any_ user-facing change? Yes, but the user-facing change is technically reduced before this PR, as we revert back the behavioral change partially from SPARK-43183, which wasn't released yet. ### How was this patch tested? Modified tests. Also manually ran 3rd party data source tests which were broken with Spark 3.5.0 RC which succeeded with this change. 
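For reviewers, a minimal sketch of the two callback paths after this change (the class name and wiring below are illustrative, not part of this patch): listeners still receive `QueryIdleEvent` for idle triggers as introduced by SPARK-43183, while `StreamingQuery.lastProgress`/`recentProgress` go back to being populated with zero-row progress entries while the query is idle.

```scala
// Illustrative listener showing the callback split that SPARK-43183 introduced
// and that this PR keeps intact: idle triggers surface as QueryIdleEvent,
// executed batches as QueryProgressEvent.
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

class IdleAwareListener extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit = ()

  override def onQueryProgress(event: QueryProgressEvent): Unit = {
    // Fired only for triggers that actually executed a batch.
    println(s"progress: batch=${event.progress.batchId}, rows=${event.progress.numInputRows}")
  }

  override def onQueryIdle(event: QueryIdleEvent): Unit = {
    // Fired once the query has been idle for longer than
    // spark.sql.streaming.noDataProgressEventInterval.
    println(s"idle since ${event.timestamp} (runId=${event.runId})")
  }

  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = ()
}

// spark.streams.addListener(new IdleAwareListener())
// After this revert, query.lastProgress / query.recentProgress are again updated
// with zero-row entries during idle periods, even though no QueryProgressEvent
// is posted for them.
```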
### Was this patch authored or co-authored using generative AI tooling? No. Closes #42773 from HeartSaVioR/SPARK-45045. Authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../streaming/ProgressReporter.scala | 164 ++++++++++-------- .../StreamingQueryListenerSuite.scala | 5 +- ...StreamingQueryStatusAndProgressSuite.scala | 41 ++++- 3 files changed, 135 insertions(+), 75 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 6dbecd186dc64..c0bd94e7d6cd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -89,7 +89,7 @@ trait ProgressReporter extends Logging { sparkSession.sessionState.conf.streamingNoDataProgressEventInterval // The timestamp we report an event that has not executed anything - private var lastNoExecutionProgressEventTime = triggerClock.getTimeMillis() + private var lastNoExecutionProgressEventTime = Long.MinValue private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 timestampFormat.setTimeZone(DateTimeUtils.getTimeZone("UTC")) @@ -142,21 +142,37 @@ trait ProgressReporter extends Logging { latestStreamProgress = to } - private def updateProgress(newProgress: StreamingQueryProgress): Unit = { + private def addNewProgress(newProgress: StreamingQueryProgress): Unit = { progressBuffer.synchronized { progressBuffer += newProgress while (progressBuffer.length >= sparkSession.sqlContext.conf.streamingProgressRetention) { progressBuffer.dequeue() } } + } + + private def updateProgress(newProgress: StreamingQueryProgress): Unit = { + // Reset noDataEventTimestamp if we processed any data + lastNoExecutionProgressEventTime = triggerClock.getTimeMillis() + + addNewProgress(newProgress) postEvent(new QueryProgressEvent(newProgress)) logInfo(s"Streaming query made progress: $newProgress") } - private def postIdleness(): Unit = { - postEvent(new QueryIdleEvent(id, runId, formatTimestamp(currentTriggerStartTimestamp))) - logInfo(s"Streaming query has been idle and waiting for new data more than " + - s"${noDataProgressEventInterval} ms.") + private def updateIdleness(newProgress: StreamingQueryProgress): Unit = { + val now = triggerClock.getTimeMillis() + if (now - noDataProgressEventInterval >= lastNoExecutionProgressEventTime) { + addNewProgress(newProgress) + if (lastNoExecutionProgressEventTime > Long.MinValue) { + postEvent(new QueryIdleEvent(newProgress.id, newProgress.runId, + formatTimestamp(currentTriggerStartTimestamp))) + logInfo(s"Streaming query has been idle and waiting for new data more than " + + s"$noDataProgressEventInterval ms.") + } + + lastNoExecutionProgressEventTime = now + } } /** @@ -172,96 +188,102 @@ trait ProgressReporter extends Logging { currentTriggerLatestOffsets != null) currentTriggerEndTimestamp = triggerClock.getTimeMillis() - if (hasExecuted) { - val executionStats = extractExecutionStats(hasNewData) - val processingTimeMills = currentTriggerEndTimestamp - currentTriggerStartTimestamp - val processingTimeSec = Math.max(1L, processingTimeMills).toDouble / MILLIS_PER_SECOND + val executionStats = extractExecutionStats(hasNewData, hasExecuted) + val processingTimeMills = currentTriggerEndTimestamp - currentTriggerStartTimestamp + val processingTimeSec = Math.max(1L, processingTimeMills).toDouble / 
MILLIS_PER_SECOND - val inputTimeSec = if (lastTriggerStartTimestamp >= 0) { - (currentTriggerStartTimestamp - lastTriggerStartTimestamp).toDouble / MILLIS_PER_SECOND - } else { - Double.PositiveInfinity - } - logDebug(s"Execution stats: $executionStats") - - val sourceProgress = sources.distinct.map { source => - val numRecords = executionStats.inputRows.getOrElse(source, 0L) - val sourceMetrics = source match { - case withMetrics: ReportsSourceMetrics => - withMetrics.metrics(Optional.ofNullable(latestStreamProgress.get(source).orNull)) - case _ => Map[String, String]().asJava - } - new SourceProgress( - description = source.toString, - startOffset = currentTriggerStartOffsets.get(source).orNull, - endOffset = currentTriggerEndOffsets.get(source).orNull, - latestOffset = currentTriggerLatestOffsets.get(source).orNull, - numInputRows = numRecords, - inputRowsPerSecond = numRecords / inputTimeSec, - processedRowsPerSecond = numRecords / processingTimeSec, - metrics = sourceMetrics - ) - } + val inputTimeSec = if (lastTriggerStartTimestamp >= 0) { + (currentTriggerStartTimestamp - lastTriggerStartTimestamp).toDouble / MILLIS_PER_SECOND + } else { + Double.PositiveInfinity + } + logDebug(s"Execution stats: $executionStats") - val sinkOutput = sinkCommitProgress.map(_.numOutputRows) - val sinkMetrics = sink match { - case withMetrics: ReportsSinkMetrics => - withMetrics.metrics() + val sourceProgress = sources.distinct.map { source => + val numRecords = executionStats.inputRows.getOrElse(source, 0L) + val sourceMetrics = source match { + case withMetrics: ReportsSourceMetrics => + withMetrics.metrics(Optional.ofNullable(latestStreamProgress.get(source).orNull)) case _ => Map[String, String]().asJava } + new SourceProgress( + description = source.toString, + startOffset = currentTriggerStartOffsets.get(source).orNull, + endOffset = currentTriggerEndOffsets.get(source).orNull, + latestOffset = currentTriggerLatestOffsets.get(source).orNull, + numInputRows = numRecords, + inputRowsPerSecond = numRecords / inputTimeSec, + processedRowsPerSecond = numRecords / processingTimeSec, + metrics = sourceMetrics + ) + } + + val sinkOutput = if (hasExecuted) { + sinkCommitProgress.map(_.numOutputRows) + } else { + sinkCommitProgress.map(_ => 0L) + } - val sinkProgress = SinkProgress( - sink.toString, sinkOutput, sinkMetrics) - - val observedMetrics = extractObservedMetrics(hasNewData, lastExecution) - - val newProgress = new StreamingQueryProgress( - id = id, - runId = runId, - name = name, - timestamp = formatTimestamp(currentTriggerStartTimestamp), - batchId = currentBatchId, - batchDuration = processingTimeMills, - durationMs = - new java.util.HashMap(currentDurationsMs.toMap.mapValues(long2Long).toMap.asJava), - eventTime = new java.util.HashMap(executionStats.eventTimeStats.asJava), - stateOperators = executionStats.stateOperators.toArray, - sources = sourceProgress.toArray, - sink = sinkProgress, - observedMetrics = new java.util.HashMap(observedMetrics.asJava)) - - // Reset noDataEventTimestamp if we processed any data - lastNoExecutionProgressEventTime = triggerClock.getTimeMillis() + val sinkMetrics = sink match { + case withMetrics: ReportsSinkMetrics => + withMetrics.metrics() + case _ => Map[String, String]().asJava + } + + val sinkProgress = SinkProgress( + sink.toString, sinkOutput, sinkMetrics) + + val observedMetrics = extractObservedMetrics(hasNewData, lastExecution) + + val newProgress = new StreamingQueryProgress( + id = id, + runId = runId, + name = name, + timestamp = 
formatTimestamp(currentTriggerStartTimestamp), + batchId = currentBatchId, + batchDuration = processingTimeMills, + durationMs = + new java.util.HashMap(currentDurationsMs.toMap.mapValues(long2Long).toMap.asJava), + eventTime = new java.util.HashMap(executionStats.eventTimeStats.asJava), + stateOperators = executionStats.stateOperators.toArray, + sources = sourceProgress.toArray, + sink = sinkProgress, + observedMetrics = new java.util.HashMap(observedMetrics.asJava)) + + if (hasExecuted) { updateProgress(newProgress) } else { - val now = triggerClock.getTimeMillis() - if (now - noDataProgressEventInterval >= lastNoExecutionProgressEventTime) { - lastNoExecutionProgressEventTime = now - postIdleness() - } + updateIdleness(newProgress) } currentStatus = currentStatus.copy(isTriggerActive = false) } /** Extract statistics about stateful operators from the executed query plan. */ - private def extractStateOperatorMetrics(): Seq[StateOperatorProgress] = { - assert(lastExecution != null, "lastExecution is not available") + private def extractStateOperatorMetrics(hasExecuted: Boolean): Seq[StateOperatorProgress] = { + if (lastExecution == null) return Nil + // lastExecution could belong to one of the previous triggers if `!hasExecuted`. + // Walking the plan again should be inexpensive. lastExecution.executedPlan.collect { case p if p.isInstanceOf[StateStoreWriter] => - p.asInstanceOf[StateStoreWriter].getProgress() + val progress = p.asInstanceOf[StateStoreWriter].getProgress() + if (hasExecuted) { + progress + } else { + progress.copy(newNumRowsUpdated = 0, newNumRowsDroppedByWatermark = 0) + } } } /** Extracts statistics from the most recent query execution. */ - private def extractExecutionStats(hasNewData: Boolean): ExecutionStats = { + private def extractExecutionStats(hasNewData: Boolean, hasExecuted: Boolean): ExecutionStats = { val hasEventTime = logicalPlan.collect { case e: EventTimeWatermark => e }.nonEmpty val watermarkTimestamp = if (hasEventTime) Map("watermark" -> formatTimestamp(offsetSeqMetadata.batchWatermarkMs)) else Map.empty[String, String] // SPARK-19378: Still report metrics even though no data was processed while reporting progress. 
- val stateOperators = extractStateOperatorMetrics() + val stateOperators = extractStateOperatorMetrics(hasExecuted) if (!hasNewData) { return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 5b5e8732e0dc9..52b740bc5c34f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -331,9 +331,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } true } - // `recentProgress` should not receive any events + // `recentProgress` should not receive too many no data events actions += AssertOnQuery { q => - q.recentProgress.isEmpty + q.recentProgress.size > 1 && q.recentProgress.size <= 11 } testStream(input.toDS)(actions.toSeq: _*) spark.sparkContext.listenerBus.waitUntilEmpty() @@ -524,7 +524,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { testStream(result)( StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock), AddData(input, 10), - // checkProgressEvent(1), AdvanceManualClock(10), checkProgressEvent(1), AdvanceManualClock(90), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 1b6005257c0ae..28134ec9d9144 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -24,11 +24,13 @@ import scala.collection.JavaConverters._ import org.json4s.jackson.JsonMethods._ import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._ import org.apache.spark.sql.streaming.StreamingQuerySuite.clock import org.apache.spark.sql.streaming.util.StreamManualClock @@ -288,6 +290,42 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { } } + test("SPARK-19378: Continue reporting stateOp metrics even if there is no active trigger") { + import testImplicits._ + + withSQLConf(SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10") { + val inputData = MemoryStream[Int] + + val query = inputData.toDS().toDF("value") + .select($"value") + .groupBy($"value") + .agg(count("*")) + .writeStream + .queryName("metric_continuity") + .format("memory") + .outputMode("complete") + .start() + try { + inputData.addData(1, 2) + query.processAllAvailable() + + val progress = query.lastProgress + assert(progress.stateOperators.length > 0) + // Should emit new progresses every 10 ms, but we could be facing a slow Jenkins + eventually(timeout(1.minute)) { + val nextProgress = query.lastProgress + assert(nextProgress.timestamp !== progress.timestamp) + assert(nextProgress.numInputRows === 0) + assert(nextProgress.stateOperators.head.numRowsTotal === 2) + 
assert(nextProgress.stateOperators.head.numRowsUpdated === 0) + assert(nextProgress.sink.numOutputRows === 0) + } + } finally { + query.stop() + } + } + } + test("SPARK-29973: Make `processedRowsPerSecond` calculated more accurately and meaningfully") { import testImplicits._ @@ -300,7 +338,8 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { AdvanceManualClock(1000), waitUntilBatchProcessed, AssertOnQuery(query => { - assert(query.lastProgress == null) + assert(query.lastProgress.numInputRows == 0) + assert(query.lastProgress.processedRowsPerSecond == 0.0d) true }), AddData(inputData, 1, 2), From 74c1f02531e78dde34cbd311c3ed8feed7aa7fe5 Mon Sep 17 00:00:00 2001 From: Ivan Sadikov Date: Mon, 4 Sep 2023 12:10:07 +0900 Subject: [PATCH 34/35] [SPARK-44940][SQL] Improve performance of JSON parsing when "spark.sql.json.enablePartialResults" is enabled ### What changes were proposed in this pull request? The PR improves JSON parsing when `spark.sql.json.enablePartialResults` is enabled: - Fixes the issue when using nested arrays `ClassCastException: org.apache.spark.sql.catalyst.util.GenericArrayData cannot be cast to org.apache.spark.sql.catalyst.InternalRow` - Improves parsing of the nested struct fields, e.g. `{"a1": "AAA", "a2": [{"f1": "", "f2": ""}], "a3": "id1", "a4": "XXX"}` used to be parsed as `|AAA|NULL |NULL|NULL|` and now is parsed as `|AAA|[{NULL, }]|id1|XXX|`. - Improves performance of nested JSON parsing. The initial implementation would throw too many exceptions when multiple nested fields failed to parse. When the config is disabled, it is not a problem because the entire record is marked as NULL. The internal benchmarks show the performance improvement from slowdown of over 160% to an improvement of 7-8% compared to the master branch when the flag is enabled. I will create a follow-up ticket to add a benchmark for this regression. ### Why are the changes needed? Fixes some corner cases in JSON parsing and improves performance when `spark.sql.json.enablePartialResults` is enabled. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? I added tests to verify nested structs, maps, and arrays can be parsed without affecting the subsequent fields in the JSON. I also updated the existing tests when `spark.sql.json.enablePartialResults` is enabled because we parse more data now. I added a benchmark to check performance. Before the change (master, a45a3a3d60cb97b107a177ad16bfe36372bc3e9b): ``` [info] OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws [info] Intel(R) Xeon(R) Platinum 8375C CPU 2.90GHz [info] Partial JSON results: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] parse invalid JSON 9537 9820 452 0.0 953651.6 1.0X ``` After the change (this PR): ``` OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws Intel(R) Xeon(R) Platinum 8375C CPU 2.90GHz Partial JSON results: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ parse invalid JSON 3100 3106 6 0.0 309967.6 1.0X ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42667 from sadikovi/SPARK-44940. 
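For reference, a minimal sketch of the user-visible behavior described above (the scratch path and the active `spark` session are assumed; the input data and expected rows mirror the new JsonSuite test added in this patch):

```scala
// Illustrative sketch: the first record's "a2[0].f1" is an empty string that cannot
// be parsed as a number, so only f1 becomes null instead of nulling out a2/a3/a4.
import spark.implicits._  // `spark` is an active SparkSession (e.g. spark-shell)

val path = "/tmp/spark-44940-demo"  // assumed scratch location
Seq(
  """{"a1": "AAA", "a2": [{"f1": "", "f2": ""}], "a3": "id1", "a4": "XXX"}""",
  """{"a1": "BBB", "a2": [{"f1": 12, "f2": ""}], "a3": "id2", "a4": "YYY"}""")
  .toDF().repartition(1).write.mode("overwrite").text(path)

spark.conf.set("spark.sql.json.enablePartialResults", "true")
spark.read.json(path).show(truncate = false)
// row 1: a1=AAA, a2=[{NULL, }], a3=id1, a4=XXX   <- only f1 is null
// row 2: a1=BBB, a2=[{12, }],   a3=id2, a4=YYY
// With the flag set to "false", row 1 instead degrades to AAA, NULL, NULL, NULL.
```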
Authored-by: Ivan Sadikov Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/json/JacksonParser.scala | 41 ++++- .../catalyst/util/BadRecordException.scala | 55 +++++- .../sql/errors/QueryExecutionErrors.scala | 12 +- sql/core/benchmarks/JsonBenchmark-results.txt | 152 ++++++++-------- .../apache/spark/sql/JsonFunctionsSuite.scala | 20 ++- .../datasources/json/JsonBenchmark.scala | 28 +++ .../datasources/json/JsonSuite.scala | 170 +++++++++++++++++- 7 files changed, 384 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 91c17a475cd94..eace96ac87291 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -420,17 +420,17 @@ class JacksonParser( case VALUE_STRING if parser.getTextLength < 1 && allowEmptyString => dataType match { case FloatType | DoubleType | TimestampType | DateType => - throw QueryExecutionErrors.emptyJsonFieldValueError(dataType) + throw EmptyJsonFieldValueException(dataType) case _ => null } case VALUE_STRING if parser.getTextLength < 1 => - throw QueryExecutionErrors.emptyJsonFieldValueError(dataType) + throw EmptyJsonFieldValueException(dataType) case token => // We cannot parse this token based on the given data type. So, we throw a // RuntimeException and this exception will be caught by `parse` method. - throw QueryExecutionErrors.cannotParseJSONFieldError(parser, token, dataType) + throw CannotParseJSONFieldException(parser.getCurrentName, parser.getText, token, dataType) } /** @@ -459,6 +459,11 @@ class JacksonParser( bitmask(index) = false } catch { case e: SparkUpgradeException => throw e + case err: PartialValueException if enablePartialResults => + badRecordException = badRecordException.orElse(Some(err.cause)) + row.update(index, err.partialResult) + skipRow = structFilters.skipRow(row, index) + bitmask(index) = false case NonFatal(e) if isRoot || enablePartialResults => badRecordException = badRecordException.orElse(Some(e)) parser.skipChildren() @@ -508,7 +513,7 @@ class JacksonParser( if (badRecordException.isEmpty) { mapData } else { - throw PartialResultException(InternalRow(mapData), badRecordException.get) + throw PartialMapDataResultException(mapData, badRecordException.get) } } @@ -543,10 +548,21 @@ class JacksonParser( throw PartialResultArrayException(arrayData.toArray[InternalRow](schema), badRecordException.get) } else { - throw PartialResultException(InternalRow(arrayData), badRecordException.get) + throw PartialArrayDataResultException(arrayData, badRecordException.get) } } + /** + * Converts the non-stacktrace exceptions to user-friendly QueryExecutionErrors. + */ + private def convertCauseForPartialResult(err: Throwable): Throwable = err match { + case CannotParseJSONFieldException(fieldName, fieldValue, jsonType, dataType) => + QueryExecutionErrors.cannotParseJSONFieldError(fieldName, fieldValue, jsonType, dataType) + case EmptyJsonFieldValueException(dataType) => + QueryExecutionErrors.emptyJsonFieldValueError(dataType) + case _ => err + } + /** * Parse the JSON input to the set of [[InternalRow]]s. 
* @@ -589,12 +605,25 @@ class JacksonParser( throw BadRecordException( record = () => recordLiteral(record), partialResults = () => Array(row), - cause) + convertCauseForPartialResult(cause)) case PartialResultArrayException(rows, cause) => throw BadRecordException( record = () => recordLiteral(record), partialResults = () => rows, cause) + // These exceptions should never be thrown outside of JacksonParser. + // They are used for the control flow in the parser. We add them here for completeness + // since they also indicate a bad record. + case PartialArrayDataResultException(arrayData, cause) => + throw BadRecordException( + record = () => recordLiteral(record), + partialResults = () => Array(InternalRow(arrayData)), + convertCauseForPartialResult(cause)) + case PartialMapDataResultException(mapData, cause) => + throw BadRecordException( + record = () => recordLiteral(record), + partialResults = () => Array(InternalRow(mapData)), + convertCauseForPartialResult(cause)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala index 7bf01fba8cd9b..65a56c1064e45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -17,19 +17,44 @@ package org.apache.spark.sql.catalyst.util +import com.fasterxml.jackson.core.JsonToken + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String +abstract class PartialValueException(val cause: Throwable) extends Exception(cause) { + def partialResult: Serializable + override def getStackTrace(): Array[StackTraceElement] = cause.getStackTrace() + override def fillInStackTrace(): Throwable = this +} + /** - * Exception thrown when the underlying parser returns a partial result of parsing. + * Exception thrown when the underlying parser returns a partial result of parsing an object/row. * @param partialResult the partial result of parsing a bad record. * @param cause the actual exception about why the parser cannot return full result. */ case class PartialResultException( - partialResult: InternalRow, - cause: Throwable) - extends Exception(cause) + override val partialResult: InternalRow, + override val cause: Throwable) extends PartialValueException(cause) + +/** + * Exception thrown when the underlying parser returns a partial array result. + * @param partialResult the partial array result. + * @param cause the actual exception about why the parser cannot return full result. + */ +case class PartialArrayDataResultException( + override val partialResult: ArrayData, + override val cause: Throwable) extends PartialValueException(cause) + +/** + * Exception thrown when the underlying parser returns a partial map result. + * @param partialResult the partial map result. + * @param cause the actual exception about why the parser cannot return full result. + */ +case class PartialMapDataResultException( + override val partialResult: MapData, + override val cause: Throwable) extends PartialValueException(cause) /** * Exception thrown when the underlying parser returns partial result list of parsing. 
@@ -65,3 +90,25 @@ case class StringAsDataTypeException( fieldName: String, fieldValue: String, dataType: DataType) extends RuntimeException() + +/** + * No-stacktrace equivalent of `QueryExecutionErrors.cannotParseJSONFieldError`. + * Used for code control flow in the parser without overhead of creating a full exception. + */ +case class CannotParseJSONFieldException( + fieldName: String, + fieldValue: String, + jsonType: JsonToken, + dataType: DataType) extends RuntimeException() { + override def getStackTrace(): Array[StackTraceElement] = new Array[StackTraceElement](0) + override def fillInStackTrace(): Throwable = this +} + +/** + * No-stacktrace equivalent of `QueryExecutionErrors.emptyJsonFieldValueError`. + * Used for code control flow in the parser without overhead of creating a full exception. + */ +case class EmptyJsonFieldValueException(dataType: DataType) extends RuntimeException() { + override def getStackTrace(): Array[StackTraceElement] = new Array[StackTraceElement](0) + override def fillInStackTrace(): Throwable = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 8e80d6570c4a5..2d655be0e700c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1279,11 +1279,19 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def cannotParseJSONFieldError(parser: JsonParser, jsonType: JsonToken, dataType: DataType) : SparkRuntimeException = { + cannotParseJSONFieldError(parser.getCurrentName, parser.getText, jsonType, dataType) + } + + def cannotParseJSONFieldError( + fieldName: String, + fieldValue: String, + jsonType: JsonToken, + dataType: DataType): SparkRuntimeException = { new SparkRuntimeException( errorClass = "CANNOT_PARSE_JSON_FIELD", messageParameters = Map( - "fieldName" -> toSQLValue(parser.getCurrentName, StringType), - "fieldValue" -> parser.getText, + "fieldName" -> toSQLValue(fieldName, StringType), + "fieldValue" -> fieldValue, "jsonType" -> jsonType.toString(), "dataType" -> toSQLType(dataType))) } diff --git a/sql/core/benchmarks/JsonBenchmark-results.txt b/sql/core/benchmarks/JsonBenchmark-results.txt index 55f66f7bb24ed..e53c780114184 100644 --- a/sql/core/benchmarks/JsonBenchmark-results.txt +++ b/sql/core/benchmarks/JsonBenchmark-results.txt @@ -3,121 +3,125 @@ Benchmark for performance of JSON parsing ================================================================================================ Preparing data for benchmarking ... -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz JSON schema inferring: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 3720 3843 121 1.3 743.9 1.0X -UTF-8 is set 5412 5455 45 0.9 1082.4 0.7X +No encoding 2084 2134 46 2.4 416.8 1.0X +UTF-8 is set 3077 3093 14 1.6 615.3 0.7X Preparing data for benchmarking ... 
-OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz count a short column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 3234 3254 33 1.5 646.7 1.0X -UTF-8 is set 4847 4868 21 1.0 969.5 0.7X +No encoding 2854 2863 8 1.8 570.8 1.0X +UTF-8 is set 4066 4066 1 1.2 813.1 0.7X Preparing data for benchmarking ... -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz count a wide column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 5702 5794 101 0.2 5702.1 1.0X -UTF-8 is set 9526 9607 73 0.1 9526.1 0.6X +No encoding 3348 3368 26 0.3 3347.8 1.0X +UTF-8 is set 5215 5239 22 0.2 5214.7 0.6X Preparing data for benchmarking ... -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz select wide row: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -No encoding 18318 18448 199 0.0 366367.7 1.0X -UTF-8 is set 19791 19887 99 0.0 395817.1 0.9X +No encoding 11046 11102 54 0.0 220928.4 1.0X +UTF-8 is set 12135 12181 54 0.0 242697.4 0.9X Preparing data for benchmarking ... -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz Select a subset of 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 10 columns 2531 2570 51 0.4 2531.3 1.0X -Select 1 column 1867 1882 16 0.5 1867.0 1.4X +Select 10 columns 2486 2488 2 0.4 2486.5 1.0X +Select 1 column 1505 1506 2 0.7 1504.6 1.7X Preparing data for benchmarking ... 
-OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz creation of JSON parser per line: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Short column without encoding 868 875 7 1.2 868.4 1.0X -Short column with UTF-8 1151 1163 11 0.9 1150.9 0.8X -Wide column without encoding 12063 12299 205 0.1 12063.0 0.1X -Wide column with UTF-8 16095 16136 51 0.1 16095.3 0.1X +Short column without encoding 888 889 3 1.1 887.6 1.0X +Short column with UTF-8 1134 1136 2 0.9 1134.3 0.8X +Wide column without encoding 8012 8056 51 0.1 8012.4 0.1X +Wide column with UTF-8 9830 9844 22 0.1 9829.7 0.1X Preparing data for benchmarking ... -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz JSON functions: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 165 170 4 6.1 164.7 1.0X -from_json 2339 2386 77 0.4 2338.9 0.1X -json_tuple 2667 2730 55 0.4 2667.3 0.1X -get_json_object 2627 2659 32 0.4 2627.1 0.1X +Text read 85 87 2 11.7 85.4 1.0X +from_json 1706 1711 4 0.6 1706.4 0.1X +json_tuple 1528 1534 7 0.7 1528.2 0.1X +get_json_object 1275 1286 17 0.8 1275.0 0.1X Preparing data for benchmarking ... -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz Dataset of json strings: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 700 715 20 7.1 140.1 1.0X -schema inferring 3144 3166 20 1.6 628.7 0.2X -parsing 3261 3271 9 1.5 652.1 0.2X +Text read 369 370 1 13.6 73.8 1.0X +schema inferring 1880 1883 4 2.7 376.0 0.2X +parsing 3731 3737 8 1.3 746.1 0.1X Preparing data for benchmarking ... 
-OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz Json files in the per-line mode: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Text read 1096 1105 12 4.6 219.1 1.0X -Schema inferring 3818 3830 16 1.3 763.6 0.3X -Parsing without charset 4107 4137 32 1.2 821.4 0.3X -Parsing with UTF-8 5717 5763 41 0.9 1143.3 0.2X +Text read 553 579 32 9.0 110.6 1.0X +Schema inferring 2195 2196 2 2.3 439.0 0.3X +Parsing without charset 4272 4274 3 1.2 854.3 0.1X +Parsing with UTF-8 5459 5464 6 0.9 1091.7 0.1X -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Create a dataset of timestamps 199 202 3 5.0 198.9 1.0X -to_json(timestamp) 1458 1487 26 0.7 1458.0 0.1X -write timestamps to files 1232 1262 26 0.8 1232.5 0.2X -Create a dataset of dates 231 237 5 4.3 230.8 0.9X -to_json(date) 956 966 10 1.0 956.5 0.2X -write dates to files 785 793 10 1.3 785.4 0.3X +Create a dataset of timestamps 102 112 13 9.8 101.9 1.0X +to_json(timestamp) 840 841 1 1.2 839.6 0.1X +write timestamps to files 692 696 4 1.4 692.0 0.1X +Create a dataset of dates 120 121 1 8.4 119.7 0.9X +to_json(date) 589 591 2 1.7 589.3 0.2X +write dates to files 442 445 2 2.3 442.3 0.2X -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------- -read timestamp text from files 294 300 6 3.4 293.8 1.0X -read timestamps from files 3254 3283 49 0.3 3254.0 0.1X -infer timestamps from files 8390 8528 165 0.1 8389.8 0.0X -read date text from files 269 276 7 3.7 269.3 1.1X -read date from files 1178 1192 13 0.8 1177.8 0.2X -timestamp strings 406 418 15 2.5 406.2 0.7X -parse timestamps from Dataset[String] 3700 3713 16 0.3 3699.5 0.1X -infer timestamps from Dataset[String] 8604 8647 65 0.1 8604.0 0.0X -date strings 464 479 14 2.2 463.7 0.6X -parse dates from Dataset[String] 1528 1538 10 0.7 1527.7 0.2X -from_json(timestamp) 5402 5429 26 0.2 5401.8 0.1X -from_json(date) 2948 2966 17 0.3 2947.5 0.1X -infer error timestamps from Dataset[String] with default format 2358 2434 70 0.4 2357.6 0.1X -infer error timestamps from Dataset[String] with user-provided format 2363 2390 36 0.4 2362.9 0.1X -infer error timestamps from Dataset[String] with legacy format 2248 2287 35 0.4 2248.3 0.1X +read timestamp text from files 143 145 4 7.0 142.6 1.0X +read timestamps from files 2449 2469 17 0.4 2448.6 0.1X +infer timestamps from files 5579 5596 15 0.2 5578.8 0.0X +read date text from files 132 139 7 7.6 131.7 
1.1X +read date from files 1017 1020 2 1.0 1017.5 0.1X +timestamp strings 227 230 3 4.4 227.2 0.6X +parse timestamps from Dataset[String] 2827 2830 3 0.4 2826.5 0.1X +infer timestamps from Dataset[String] 6001 6008 6 0.2 6001.2 0.0X +date strings 259 261 2 3.9 259.0 0.6X +parse dates from Dataset[String] 1382 1387 4 0.7 1382.3 0.1X +from_json(timestamp) 3557 3561 7 0.3 3556.8 0.0X +from_json(date) 2146 2148 2 0.5 2146.4 0.1X +infer error timestamps from Dataset[String] with default format 1989 1993 4 0.5 1989.3 0.1X +infer error timestamps from Dataset[String] with user-provided format 1922 1925 3 0.5 1922.1 0.1X +infer error timestamps from Dataset[String] with legacy format 1919 1923 4 0.5 1919.1 0.1X -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -w/o filters 22544 22661 109 0.0 225436.4 1.0X -pushdown disabled 21045 21213 188 0.0 210452.6 1.1X -w/ filters 893 904 10 0.1 8931.8 25.2X - +w/o filters 14387 14399 12 0.0 143871.9 1.0X +pushdown disabled 13891 13899 7 0.0 138912.3 1.0X +w/ filters 782 784 2 0.1 7820.5 18.4X +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +Partial JSON results: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +parse invalid JSON 3100 3106 6 0.0 309967.6 1.0X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 187fab75f6378..a76e102fe913f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -1021,16 +1021,23 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { .add("c2", ArrayType(new StructType().add("c3", LongType).add("c4", StringType))) val df1 = Seq("""{"c2": [19], "c1": 123456}""").toDF("c0") checkAnswer(df1.select(from_json($"c0", st)), Row(Row(123456, null))) + val df2 = Seq("""{"data": {"c2": [19], "c1": 123456}}""").toDF("c0") - checkAnswer(df2.select(from_json($"c0", new StructType().add("data", st))), Row(Row(null))) + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + checkAnswer( + df2.select(from_json($"c0", new StructType().add("data", st))), + Row(Row(Row(123456, null))) + ) + } + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "false") { + checkAnswer(df2.select(from_json($"c0", new StructType().add("data", st))), Row(Row(null))) + } + val df3 = Seq("""[{"c2": [19], "c1": 123456}]""").toDF("c0") withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { - val df3 = Seq("""[{"c2": [19], "c1": 123456}]""").toDF("c0") checkAnswer(df3.select(from_json($"c0", ArrayType(st))), Row(Array(Row(123456, null)))) } - withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "false") { - val df3 = Seq("""[{"c2": [19], "c1": 123456}]""").toDF("c0") checkAnswer(df3.select(from_json($"c0", ArrayType(st))), Row(null)) } @@ -1119,14 +1126,13 @@ class JsonFunctionsSuite extends 
QueryTest with SharedSparkSession { ) ) - // Value "a" cannot be parsed as an integer, - // the error cascades to "c2", thus making its value null. + // Value "a" cannot be parsed as an integer, c2 value is null. val df = Seq("""[{"c1": [{"c2": ["a"]}]}]""").toDF("c0") withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { checkAnswer( df.select(from_json($"c0", ArrayType(st))), - Row(Array(Row(null))) + Row(Array(Row(Seq(Row(null))))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala index c522378a65d7c..5b86543648f07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala @@ -542,6 +542,33 @@ object JsonBenchmark extends SqlBasedBenchmark { } } + private def partialResultBenchmark(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark("Partial JSON results", rowsNum, output = output) + val colsNum = 1000 + + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + + def data: Dataset[String] = { + spark.range(0, rowsNum, 1, 1).mapPartitions { iter => + iter.map { i => + (0 until colsNum).map { j => + // Only the last column has an integer value. + if (j < colsNum - 1) s""""col${i}":"foo_${j}"""" else s""""col${i}":${j}""" + }.mkString("{", ", ", "}") + } + }.select($"value").as[String] + } + + benchmark.addCase("parse invalid JSON", numIters) { _ => + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + spark.read.schema(schema).json(data).noop() + } + } + + benchmark.run() + } + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val numIters = 3 runBenchmark("Benchmark for performance of JSON parsing") { @@ -558,6 +585,7 @@ object JsonBenchmark extends SqlBasedBenchmark { // Benchmark pushdown filters that refer to top-level columns. // TODO (SPARK-32325): Add benchmarks for filters with nested column attributes. 
filtersPushdownBenchmark(rowsNum = 100 * 1000, numIters) + partialResultBenchmark(rowsNum = 10000, numIters) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 95468a1f1d77c..11779286ec25f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3437,7 +3437,7 @@ abstract class JsonSuite if (enablePartialResults) { checkAnswer( df, - Seq(Row(null, Row(1)), Row(Row(2, null), Row(2))) + Seq(Row(Row(1, null), Row(1)), Row(Row(2, null), Row(2))) ) } else { checkAnswer( @@ -3450,6 +3450,174 @@ abstract class JsonSuite } } + test("SPARK-44940: fully parse the record except f1 if partial results are enabled") { + withTempPath { path => + Seq( + """{"a1": "AAA", "a2": [{"f1": "", "f2": ""}], "a3": "id1", "a4": "XXX"}""", + """{"a1": "BBB", "a2": [{"f1": 12, "f2": ""}], "a3": "id2", "a4": "YYY"}""").toDF() + .repartition(1) + .write.text(path.getAbsolutePath) + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + val df = spark.read.json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", Seq(Row(null, "")), "id1", "XXX"), + Row("BBB", Seq(Row(12, "")), "id2", "YYY") + ) + ) + } + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "false") { + val df = spark.read.json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", null, null, null), + Row("BBB", Seq(Row(12, "")), "id2", "YYY") + ) + ) + } + } + } + + test("SPARK-44940: fully parse primitive map if partial results are enabled") { + withTempPath { path => + Seq( + """{"a1": "AAA", "a2": {"f1": "", "f2": ""}, "a3": "id1"}""", + """{"a1": "BBB", "a2": {"f1": 12, "f2": ""}, "a3": "id2"}""").toDF() + .repartition(1) + .write.text(path.getAbsolutePath) + + val schema = "a1 string, a2 map, a3 string" + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + val df = spark.read.schema(schema).json(path.getAbsolutePath) + // Although the keys match the string type and some values match the integer type, because + // some of the values do not match the type, we mark the entire map as null. 
+ checkAnswer( + df, + Seq( + Row("AAA", null, "id1"), + Row("BBB", null, "id2") + ) + ) + } + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "false") { + val df = spark.read.schema(schema).json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", null, null), + Row("BBB", null, null) + ) + ) + } + } + } + + test("SPARK-44940: fully parse map of structs if partial results are enabled") { + withTempPath { path => + Seq( + """{"a1": "AAA", "a2": {"key": {"f1": "", "f2": ""}}, "a3": "id1"}""", + """{"a1": "BBB", "a2": {"key": {"f1": 12, "f2": ""}}, "a3": "id2"}""").toDF() + .repartition(1) + .write.text(path.getAbsolutePath) + + val schema = "a1 string, a2 map>, a3 string" + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + val df = spark.read.schema(schema).json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", Map("key" -> Row(null, "")), "id1"), + Row("BBB", Map("key" -> Row(12, "")), "id2") + ) + ) + } + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "false") { + val df = spark.read.schema(schema).json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", null, null), + Row("BBB", Map("key" -> Row(12, "")), "id2") + ) + ) + } + } + } + + test("SPARK-44940: fully parse primitive arrays if partial results are enabled") { + withTempPath { path => + Seq( + """{"a1": "AAA", "a2": {"f1": [""]}, "a3": "id1", "a4": "XXX"}""", + """{"a1": "BBB", "a2": {"f1": [12]}, "a3": "id2", "a4": "YYY"}""").toDF() + .repartition(1) + .write.text(path.getAbsolutePath) + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + val df = spark.read.json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", Row(null), "id1", "XXX"), + Row("BBB", Row(Seq(12)), "id2", "YYY") + ) + ) + } + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "false") { + val df = spark.read.json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", null, null, null), + Row("BBB", Row(Seq(12)), "id2", "YYY") + ) + ) + } + } + } + + test("SPARK-44940: fully parse array of arrays if partial results are enabled") { + withTempPath { path => + Seq( + """{"a1": "AAA", "a2": [[12, ""], [""]], "a3": "id1", "a4": "XXX"}""", + """{"a1": "BBB", "a2": [[12, 34], [""]], "a3": "id2", "a4": "YYY"}""").toDF() + .repartition(1) + .write.text(path.getAbsolutePath) + + // We cannot parse `array>` type because one of the inner arrays contains a + // mismatched type. + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + val df = spark.read.json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", null, "id1", "XXX"), + Row("BBB", null, "id2", "YYY") + ) + ) + } + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "false") { + val df = spark.read.json(path.getAbsolutePath) + checkAnswer( + df, + Seq( + Row("AAA", null, "id1", "XXX"), + Row("BBB", null, "id2", "YYY") + ) + ) + } + } + } + test("SPARK-40667: validate JSON Options") { assert(JSONOptions.getAllOptions.size == 28) // Please add validation on any new Json options here From 60d8fc49bec5dae1b8cf39a0670cb640b430f520 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 4 Sep 2023 11:48:52 +0800 Subject: [PATCH 35/35] [SPARK-45042][BUILD] Upgrade jetty to 9.4.52.v20230823 ### What changes were proposed in this pull request? The pr aims to Upgrade jetty from 9.4.51.v20230217 to 9.4.52.v20230823. ### Why are the changes needed? 
- This release (tracked in https://github.com/eclipse/jetty.project/issues/7958) was sponsored by a [support contract from Webtide.com](mailto:saleswebtide.com)
- The new version fixes a possible security issue: this release provides a workaround for Security Advisory https://github.com/advisories/GHSA-58qw-p7qm-5rvh
- The release notes are as follows: https://github.com/eclipse/jetty.project/releases/tag/jetty-9.4.52.v20230823

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #42761 from panbingkun/SPARK-45042.

Authored-by: panbingkun
Signed-off-by: yangjie01
---
 dev/deps/spark-deps-hadoop-3-hive-2.3 | 4 ++--
 pom.xml                               | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 94c999ba4b6d8..59164c1f8f441 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -128,8 +128,8 @@ jersey-container-servlet/2.40//jersey-container-servlet-2.40.jar
 jersey-hk2/2.40//jersey-hk2-2.40.jar
 jersey-server/2.40//jersey-server-2.40.jar
 jettison/1.5.4//jettison-1.5.4.jar
-jetty-util-ajax/9.4.51.v20230217//jetty-util-ajax-9.4.51.v20230217.jar
-jetty-util/9.4.51.v20230217//jetty-util-9.4.51.v20230217.jar
+jetty-util-ajax/9.4.52.v20230823//jetty-util-ajax-9.4.52.v20230823.jar
+jetty-util/9.4.52.v20230823//jetty-util-9.4.52.v20230823.jar
 jline/2.14.6//jline-2.14.6.jar
 joda-time/2.12.5//joda-time-2.12.5.jar
 jodd-core/3.5.2//jodd-core-3.5.2.jar
diff --git a/pom.xml b/pom.xml
index b64a0ab15acd2..8edc3fd550c2e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -143,7 +143,7 @@
     1.13.1
     1.9.1
     shaded-protobuf
-    9.4.51.v20230217
+    9.4.52.v20230823
     4.0.3
     0.10.0