From 1a89bdc60d55394a1a9d94d4fa69fa5ab8041671 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 3 Aug 2023 11:57:34 +0900 Subject: [PATCH 01/68] [SPARK-44620][SQL][PS][CONNECT] Make `ResolvePivot` retain the `Plan_ID_TAG` ### What changes were proposed in this pull request? Make `ResolvePivot` retain the `Plan_ID_TAG` ### Why are the changes needed? to resolve the `AnalysisException` in Pandas APIs on Connect ### Does this PR introduce _any_ user-facing change? yes, new APIs enabled: 1. `frame.pivot_table` 2. `frame.transpose` 3. `series.unstack` ### How was this patch tested? enabled UTs Closes #42261 from zhengruifeng/ps_connect_analyze_pivot. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../connect/computation/test_parity_pivot.py | 17 +---------------- .../connect/frame/test_parity_reshaping.py | 11 +---------- .../tests/connect/series/test_parity_compute.py | 6 +----- .../tests/connect/test_parity_categorical.py | 6 ------ .../spark/sql/catalyst/analysis/Analyzer.scala | 10 +++++++--- 5 files changed, 10 insertions(+), 40 deletions(-) diff --git a/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py b/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py index d2c4f9ae60717..c8ec48eb06aa4 100644 --- a/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +++ b/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py @@ -16,28 +16,13 @@ # import unittest -from pyspark import pandas as ps from pyspark.pandas.tests.computation.test_pivot import FramePivotMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils class FrameParityPivotTests(FramePivotMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @property - def psdf(self): - return ps.from_pandas(self.pdf) - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_pivot_table(self): - super().test_pivot_table() - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_pivot_table_dtypes(self): - super().test_pivot_table_dtypes() + pass if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py b/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py index 98ebf3ca44a07..e4bac7b078e66 100644 --- a/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +++ b/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py @@ -16,22 +16,13 @@ # import unittest -from pyspark import pandas as ps from pyspark.pandas.tests.frame.test_reshaping import FrameReshapingMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils class FrameParityReshapingTests(FrameReshapingMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @property - def psdf(self): - return ps.from_pandas(self.pdf) - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." 
- ) - def test_transpose(self): - super().test_transpose() + pass if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py index f757d19ca6941..8876fcb139885 100644 --- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +++ b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py @@ -22,11 +22,7 @@ class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_unstack(self): - super().test_unstack() + pass if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/test_parity_categorical.py b/python/pyspark/pandas/tests/connect/test_parity_categorical.py index 3e05eb2c0f3b7..210cfce8ddbaf 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_categorical.py +++ b/python/pyspark/pandas/tests/connect/test_parity_categorical.py @@ -53,12 +53,6 @@ def test_reorder_categories(self): def test_set_categories(self): super().test_set_categories() - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_unstack(self): - super().test_unstack() - if __name__ == "__main__": from pyspark.pandas.tests.connect.test_parity_categorical import * # noqa: F401 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1de745baa0544..6c1d774a1b5fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -759,7 +759,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p - case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + case p @ Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => if (!RowOrdering.isOrderable(pivotColumn.dataType)) { throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) } @@ -823,7 +823,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))() } } - Project(groupByExprsAttr ++ pivotOutputs, secondAgg) + val newProject = Project(groupByExprsAttr ++ pivotOutputs, secondAgg) + newProject.copyTagsFrom(p) + newProject } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(e: Expression) = { @@ -857,7 +859,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Alias(filteredAggregate, outputName(value, aggregate))() } } - Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) + val newAggregate = Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) + newAggregate.copyTagsFrom(p) + newAggregate } } From 8a8471811727cd4a49696e68713d459417ebca73 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 3 Aug 2023 12:05:02 +0900 Subject: [PATCH 02/68] [SPARK-44643][SQL][PYTHON] Fix Row.__repr__ for the case the field is empty Row ### What changes were proposed in this pull 
request? Fix `Row.__repr__` for the case where the field is an empty `Row`. ```py >>> repr(Row(Row())) '<Row(<Row()>)>' ``` ### Why are the changes needed? `Row.__repr__` is broken when it contains an empty `Row`: ```py >>> repr(Row(Row())) Traceback (most recent call last): ... TypeError: not enough arguments for format string ``` ### Does this PR introduce _any_ user-facing change? `Row` that contains an empty `Row` will be shown in the REPL. ### How was this patch tested? Added the related test. Closes #42303 from ueshin/issues/SPARK-44643/repr_row. Authored-by: Takuya UESHIN Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_types.py | 9 +++++++++ python/pyspark/sql/types.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 083aa151d0dd8..7cb13693a0df9 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -1323,6 +1323,15 @@ def test_row_without_column_name(self): # test __repr__ with unicode values self.assertEqual(repr(Row("数", "量")), "<Row('数', '量')>") + # SPARK-44643: test __repr__ with empty Row + def test_row_repr_with_empty_row(self): + self.assertEqual(repr(Row(a=Row())), "Row(a=<Row()>)") + self.assertEqual(repr(Row(Row())), "<Row(<Row()>)>") + + EmptyRow = Row() + self.assertEqual(repr(Row(a=EmptyRow())), "Row(a=Row())") + self.assertEqual(repr(Row(EmptyRow())), "<Row(Row())>") + def test_empty_row(self): row = Row() self.assertEqual(len(row), 0) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index db615d339b5ae..092fa43b1d2e7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -2402,7 +2402,7 @@ def __repr__(self) -> str: "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self)) ) else: - return "<Row(%s)>" % ", ".join("%r" % field for field in self) + return "<Row(%s)>" % ", ".join(repr(field) for field in self) class DateConverter: From b390c9c55a0c16423657ba64f27b9a16ce509c23 Mon Sep 17 00:00:00 2001 From: Richard Yu Date: Thu, 3 Aug 2023 12:07:05 +0900 Subject: [PATCH 03/68] [SPARK-44059][SQL] Add better error messages for SQL named arguments ### What changes were proposed in this pull request? Correct error messages. ### Why are the changes needed? Need to have better quality messages. ### Does this PR introduce _any_ user-facing change? Error messages are more specific. ### How was this patch tested? Tested in SQLQueryTestSuite: named-function-arguments.sql and NamedParameterFunctionSuite. Authored-by: Richard Yu (cherry picked from commit 228b5dbfd7688a8efa7135d9ec7b00b71e41a38a) Closes #42177 from learningchess2003/error-messages.
Lead-authored-by: Richard Yu Co-authored-by: Richard Yu <134337791+learningchess2003@users.noreply.github.com> Signed-off-by: Hyukjin Kwon --- .../main/resources/error/error-classes.json | 8 +++---- ...outine-parameter-assignment-error-class.md | 4 ++-- docs/sql-error-conditions.md | 4 ++-- .../plans/logical/FunctionBuilderBase.scala | 24 ++++++++++--------- .../sql/errors/QueryCompilationErrors.scala | 13 ++++++---- .../NamedParameterFunctionSuite.scala | 11 +++++---- .../named-function-arguments.sql.out | 4 +++- .../results/named-function-arguments.sql.out | 4 +++- 8 files changed, 43 insertions(+), 29 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 063505228340e..a9619b97bd929 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -809,12 +809,12 @@ "subClass" : { "BOTH_POSITIONAL_AND_NAMED" : { "message" : [ - "A positional argument and named argument both referred to the same parameter." + "A positional argument and named argument both referred to the same parameter. Please remove the named argument referring to this parameter." ] }, "DOUBLE_NAMED_ARGUMENT_REFERENCE" : { "message" : [ - "More than one named argument referred to the same parameter." + "More than one named argument referred to the same parameter. Please assign a value only once." ] } }, @@ -2446,7 +2446,7 @@ }, "REQUIRED_PARAMETER_NOT_FOUND" : { "message" : [ - "Cannot invoke function <functionName> because the parameter named <parameterName> is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally or by name) and retry the query again." + "Cannot invoke function <functionName> because the parameter named <parameterName> is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally at index <index> or by name) and retry the query again." ], "sqlState" : "4274K" }, @@ -2647,7 +2647,7 @@ }, "UNEXPECTED_POSITIONAL_ARGUMENT" : { "message" : [ - "Cannot invoke function <functionName> because it contains positional argument(s) following named argument(s); please rearrange them so the positional arguments come first and then retry the query again." + "Cannot invoke function <functionName> because it contains positional argument(s) following the named argument assigned to <parameterName>; please rearrange them so the positional arguments come first and then retry the query again." ], "sqlState" : "4274K" }, diff --git a/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md b/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md index d9f14b5a55ef8..eb5ca2a0169d1 100644 --- a/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md +++ b/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md @@ -27,10 +27,10 @@ This error class has the following derived error classes: ## BOTH_POSITIONAL_AND_NAMED -A positional argument and named argument both referred to the same parameter. +A positional argument and named argument both referred to the same parameter. Please remove the named argument referring to this parameter. ## DOUBLE_NAMED_ARGUMENT_REFERENCE -More than one named argument referred to the same parameter. +More than one named argument referred to the same parameter. Please assign a value only once.
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 6ea16d7ef31b3..161f3bdbef121 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -1563,7 +1563,7 @@ The `<clause>` clause may be used at most once per `<operation>` operation. [SQLSTATE: 4274K](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) -Cannot invoke function `<functionName>` because the parameter named `<parameterName>` is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally or by name) and retry the query again. +Cannot invoke function `<functionName>` because the parameter named `<parameterName>` is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally at index `<index>` or by name) and retry the query again. ### REQUIRES_SINGLE_PART_NAMESPACE @@ -1778,7 +1778,7 @@ Parameter `<paramIndex>` of function `<functionName>` requires the `<requiredType>` -Cannot invoke function `<functionName>` because it contains positional argument(s) following named argument(s); please rearrange them so the positional arguments come first and then retry the query again. +Cannot invoke function `<functionName>` because it contains positional argument(s) following the named argument assigned to `<parameterName>`; please rearrange them so the positional arguments come first and then retry the query again. ### UNKNOWN_PROTOBUF_MESSAGE_TYPE diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 4a2b9eae98100..1088655f60cd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -104,7 +104,7 @@ object NamedParametersSupport { val positionalParametersSet = allParameterNames.take(positionalArgs.size).toSet val namedParametersSet = collection.mutable.Set[String]() - for (arg <- namedArgs) { + namedArgs.zipWithIndex.foreach { case (arg, index) => arg match { case namedArg: NamedArgumentExpression => val parameterName = namedArg.key @@ -122,7 +122,8 @@ object NamedParametersSupport { } namedParametersSet.add(namedArg.key) case _ => - throw QueryCompilationErrors.unexpectedPositionalArgument(functionName) + throw QueryCompilationErrors.unexpectedPositionalArgument( + functionName, namedArgs(index - 1).asInstanceOf[NamedArgumentExpression].key) } } @@ -141,15 +142,16 @@ object NamedParametersSupport { }.toMap // We rearrange named arguments to match their positional order.
- val rearrangedNamedArgs: Seq[Expression] = namedParameters.map { param => - namedArgMap.getOrElse( - param.name, - if (param.default.isEmpty) { - throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name) - } else { - param.default.get - } - ) + val rearrangedNamedArgs: Seq[Expression] = namedParameters.zipWithIndex.map { + case (param, index) => + namedArgMap.getOrElse( + param.name, + if (param.default.isEmpty) { + throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name, index) + } else { + param.default.get + } + ) } val rearrangedArgs = positionalArgs ++ rearrangedNamedArgs assert(rearrangedArgs.size == parameters.size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 41de0c76b3b00..1e4f779e565af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -90,12 +90,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat } def requiredParameterNotFound( - functionName: String, parameterName: String) : Throwable = { + functionName: String, parameterName: String, index: Int) : Throwable = { new AnalysisException( errorClass = "REQUIRED_PARAMETER_NOT_FOUND", messageParameters = Map( "functionName" -> toSQLId(functionName), - "parameterName" -> toSQLId(parameterName)) + "parameterName" -> toSQLId(parameterName), + "index" -> index.toString) ) } @@ -115,10 +116,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } - def unexpectedPositionalArgument(functionName: String): Throwable = { + def unexpectedPositionalArgument( + functionName: String, + precedingNamedArgument: String): Throwable = { new AnalysisException( errorClass = "UNEXPECTED_POSITIONAL_ARGUMENT", - messageParameters = Map("functionName" -> toSQLId(functionName)) + messageParameters = Map( + "functionName" -> toSQLId(functionName), + "parameterName" -> toSQLId(precedingNamedArgument)) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala index dd5cb5e7d03c8..99fed4d2ee5d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala @@ -40,6 +40,7 @@ case class DummyExpression( } object DummyExpressionBuilder extends ExpressionBuilder { + def defaultFunctionSignature: FunctionSignature = { FunctionSignature(Seq(InputParameter("k1"), InputParameter("k2"), @@ -49,11 +50,12 @@ object DummyExpressionBuilder extends ExpressionBuilder { override def functionSignature: Option[FunctionSignature] = Some(defaultFunctionSignature) + override def build(funcName: String, expressions: Seq[Expression]): Expression = DummyExpression(expressions(0), expressions(1), expressions(2), expressions(3)) } -class NamedArgumentFunctionSuite extends AnalysisTest { +class NamedParameterFunctionSuite extends AnalysisTest { final val k1Arg = Literal("v1") final val k2Arg = NamedArgumentExpression("k2", Literal("v2")) @@ -61,6 +63,7 @@ class NamedArgumentFunctionSuite extends AnalysisTest { final val k4Arg = NamedArgumentExpression("k4", Literal("v4")) 
final val namedK1Arg = NamedArgumentExpression("k1", Literal("v1-2")) final val args = Seq(k1Arg, k4Arg, k2Arg, k3Arg) + final val expectedSeq = Seq(Literal("v1"), Literal("v2"), Literal("v3"), Literal("v4")) final val signature = DummyExpressionBuilder.defaultFunctionSignature final val illegalSignature = FunctionSignature(Seq( @@ -115,8 +118,8 @@ class NamedArgumentFunctionSuite extends AnalysisTest { checkError( exception = parseRearrangeException(signature, Seq(k1Arg, k2Arg, k3Arg), "foo"), errorClass = "REQUIRED_PARAMETER_NOT_FOUND", - parameters = Map("functionName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4")) - ) + parameters = Map( + "functionName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4"), "index" -> "2")) } test("UNRECOGNIZED_PARAMETER_NAME") { @@ -134,7 +137,7 @@ class NamedArgumentFunctionSuite extends AnalysisTest { exception = parseRearrangeException(signature, Seq(k2Arg, k3Arg, k1Arg, k4Arg), "foo"), errorClass = "UNEXPECTED_POSITIONAL_ARGUMENT", - parameters = Map("functionName" -> toSQLId("foo")) + parameters = Map("functionName" -> toSQLId("foo"), "parameterName" -> toSQLId("k3")) ) } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out index 650b61b419245..11e2651c6f225 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out @@ -229,7 +229,8 @@ org.apache.spark.sql.AnalysisException "errorClass" : "UNEXPECTED_POSITIONAL_ARGUMENT", "sqlState" : "4274K", "messageParameters" : { - "functionName" : "`mask`" + "functionName" : "`mask`", + "parameterName" : "`lowerChar`" }, "queryContext" : [ { "objectType" : "", @@ -292,6 +293,7 @@ org.apache.spark.sql.AnalysisException "sqlState" : "4274K", "messageParameters" : { "functionName" : "`mask`", + "index" : "0", "parameterName" : "`str`" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out index 77c15b56c8dab..60301862a35c9 100644 --- a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out @@ -214,7 +214,8 @@ org.apache.spark.sql.AnalysisException "errorClass" : "UNEXPECTED_POSITIONAL_ARGUMENT", "sqlState" : "4274K", "messageParameters" : { - "functionName" : "`mask`" + "functionName" : "`mask`", + "parameterName" : "`lowerChar`" }, "queryContext" : [ { "objectType" : "", @@ -283,6 +284,7 @@ org.apache.spark.sql.AnalysisException "sqlState" : "4274K", "messageParameters" : { "functionName" : "`mask`", + "index" : "0", "parameterName" : "`str`" }, "queryContext" : [ { From c607f1843035199a15d277bee03e56ba99da89c9 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Thu, 3 Aug 2023 12:08:25 +0900 Subject: [PATCH 04/68] [SPARK-44488][SQL] Support deserializing long types when creating `Metadata` object from JObject ### What changes were proposed in this pull request? Adds support to deserialize long types when creating `Metadata` objects from `JObject`s. ### Why are the changes needed? Code will previously crash when adding a `long` type to the `Metadata` object. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? 
Closes #42083 from richardc-db/add_long_metadata_serialization. Authored-by: Richard Chen Signed-off-by: Hyukjin Kwon --- .../src/main/scala/org/apache/spark/sql/types/Metadata.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 4e7ac996d31e1..3677927b9a555 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -137,6 +137,8 @@ object Metadata { jObj.obj.foreach { case (key, JInt(value)) => builder.putLong(key, value.toLong) + case (key, JLong(value)) => + builder.putLong(key, value.toLong) case (key, JDouble(value)) => builder.putDouble(key, value) case (key, JBool(value)) => @@ -153,6 +155,8 @@ object Metadata { value.head match { case _: JInt => builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray) + case _: JLong => + builder.putLongArray(key, value.asInstanceOf[List[JLong]].map(_.num.toLong).toArray) case _: JDouble => builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray) case _: JBool => From 3b3e30113262455553a0bb7d668b2a7d9a23a05d Mon Sep 17 00:00:00 2001 From: pegasas <616672335@qq.com> Date: Thu, 3 Aug 2023 12:15:19 +0900 Subject: [PATCH 05/68] [SPARK-42730][CONNECT][DOCS] Update Spark Standalone Mode page ### What changes were proposed in this pull request? [SPARK-42730][CONNECT][DOCS] Add start-connect-server.sh/stop-connect-server.sh to this list and cover Spark Connect sessions - other changes needed here. ### Why are the changes needed? [SPARK-42730][CONNECT][DOCS] Add start-connect-server.sh/stop-connect-server.sh to this list and cover Spark Connect sessions - other changes needed here.. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Spark Document related patch tested. Closes #42307 from pegasas/doc. Authored-by: pegasas <616672335@qq.com> Signed-off-by: Hyukjin Kwon --- docs/spark-standalone.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index d47ff3987f95b..3e87edad0aadd 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -104,10 +104,12 @@ Once you've set up this file, you can launch or stop your cluster with the follo - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-workers.sh` - Starts a worker instance on each machine specified in the `conf/workers` file. - `sbin/start-worker.sh` - Starts a worker instance on the machine the script is executed on. +- `sbin/start-connect-server.sh` - Starts a Spark Connect server on the machine the script is executed on. - `sbin/start-all.sh` - Starts both a master and a number of workers as described above. - `sbin/stop-master.sh` - Stops the master that was started via the `sbin/start-master.sh` script. - `sbin/stop-worker.sh` - Stops all worker instances on the machine the script is executed on. - `sbin/stop-workers.sh` - Stops all worker instances on the machines specified in the `conf/workers` file. +- `sbin/stop-connect-server.sh` - Stops all Spark Connect server instances on the machine the script is executed on. - `sbin/stop-all.sh` - Stops both the master and the workers as described above. Note that these scripts must be executed on the machine you want to run the Spark master on, not your local machine. 
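As an aside on the Spark Connect scripts documented in the hunk above: once `sbin/start-connect-server.sh` is running on the cluster, a client can attach to it over the Spark Connect protocol. The following is a minimal, illustrative Python sketch that is not part of this patch; it assumes the server is reachable on localhost at the default Spark Connect port 15002, so adjust the host, port, and sample query for your deployment.

```python
# Illustrative sketch: connect to a Spark Connect server started with
# sbin/start-connect-server.sh (assumes the default port 15002 on localhost).
from pyspark.sql import SparkSession

# Build a client session against the remote Spark Connect endpoint.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Run a trivial query over the Connect session to confirm the server is reachable.
assert spark.range(5).count() == 5

# Close the client session; the server keeps running until stop-connect-server.sh.
spark.stop()
```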
From 445e3c395221126469394718b7bf190026a862f7 Mon Sep 17 00:00:00 2001 From: Amanda Liu Date: Thu, 3 Aug 2023 12:18:14 +0900 Subject: [PATCH 06/68] [SPARK-44645][PYTHON][DOCS] Update assertDataFrameEqual docs error example output ### What changes were proposed in this pull request? This PR updates the error example output for the `assertDataFrameEqual` docs, given the new error message formatting. ### Why are the changes needed? The change is needed to display the accurate `assertDataFrameEqual` error message. ### Does this PR introduce _any_ user-facing change? Yes, the PR affects the user view for the PySpark docs page. ### How was this patch tested? Existing tests Closes #42305 from asl3/update-docs-error. Authored-by: Amanda Liu Signed-off-by: Hyukjin Kwon --- python/pyspark/testing/utils.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 2a23476112fee..5461577fad1de 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -432,17 +432,15 @@ def assertDataFrameEqual( Traceback (most recent call last): ... PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % ) - --- actual - +++ expected - - Row(id='1', amount=1000.0) - ? ^ - + Row(id='1', amount=1001.0) - ? ^ - - Row(id='3', amount=2000.0) - ? ^ - + Row(id='3', amount=2003.0) - ? ^ - + *** actual *** + ! Row(id='1', amount=1000.0) + Row(id='2', amount=3000.0) + ! Row(id='3', amount=2000.0) + + *** expected *** + ! Row(id='1', amount=1001.0) + Row(id='2', amount=3000.0) + ! Row(id='3', amount=2003.0) """ if actual is None and expected is None: return True From fd624530e3b23b5d08cfa1f91f4708d1ad64716e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 3 Aug 2023 12:05:39 +0800 Subject: [PATCH 07/68] [SPARK-44572][INFRA] Clean up unused installers ASAP ### What changes were proposed in this pull request? Clean up unused installers ASAP ### Why are the changes needed? to free disk space a bit ### Does this PR introduce _any_ user-facing change? no, infra-only ### How was this patch tested? updated CI Closes #42292 from zhengruifeng/infra_packaging_cleanup. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .github/workflows/build_and_test.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 47c1be1ba863b..d9bcdfcbfa474 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -424,6 +424,7 @@ jobs: run: | curl -s https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh > miniconda.sh bash miniconda.sh -b -p $HOME/miniconda + rm miniconda.sh # Run the tests. - name: Run tests env: ${{ fromJSON(inputs.envs) }} @@ -647,6 +648,7 @@ jobs: curl -LO https://github.com/bufbuild/buf/releases/download/v1.24.0/buf-Linux-x86_64.tar.gz mkdir -p $HOME/buf tar -xvzf buf-Linux-x86_64.tar.gz -C $HOME/buf --strip-components 1 + rm buf-Linux-x86_64.tar.gz python3.9 -m pip install 'protobuf==3.20.3' 'mypy-protobuf==3.3.0' - name: Python code generation check run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi @@ -1027,6 +1029,7 @@ jobs: # TODO(SPARK-44495): Resume to use the latest minikube for k8s-integration-tests. 
curl -LO https://storage.googleapis.com/minikube/releases/v1.30.1/minikube-linux-amd64 sudo install minikube-linux-amd64 /usr/local/bin/minikube + rm minikube-linux-amd64 # Github Action limit cpu:2, memory: 6947MB, limit to 2U6G for better resource statistic minikube start --cpus 2 --memory 6144 - name: Print K8S pods and nodes info From f824d058b14e3c58b1c90f64fefc45fac105c7dd Mon Sep 17 00:00:00 2001 From: Koray Beyaz Date: Thu, 3 Aug 2023 10:57:26 +0500 Subject: [PATCH 08/68] [SPARK-42330][SQL] Assign the name `RULE_ID_NOT_FOUND` to the error class `_LEGACY_ERROR_TEMP_2175` ### What changes were proposed in this pull request? - Rename _LEGACY_ERROR_TEMP_2175 as RULE_ID_NOT_FOUND - Add a test case for the error class. ### Why are the changes needed? We are migrating to error classes. ### Does this PR introduce _any_ user-facing change? Yes, the error message will include the error class name ### How was this patch tested? `testOnly *RuleIdCollectionSuite` and Github Actions Closes #40991 from kori73/SPARK-42330. Lead-authored-by: Koray Beyaz Co-authored-by: Koray Beyaz Signed-off-by: Max Gekk --- .../utils/src/main/resources/error/error-classes.json | 11 ++++++----- docs/sql-error-conditions.md | 6 ++++++ .../spark/sql/errors/QueryExecutionErrors.scala | 5 ++--- .../spark/sql/errors/QueryExecutionErrorsSuite.scala | 11 +++++++++++ 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index a9619b97bd929..20f2ab4eb24eb 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -2471,6 +2471,12 @@ ], "sqlState" : "42883" }, + "RULE_ID_NOT_FOUND" : { + "message" : [ + "Not found an id for the rule name \"<ruleName>\". Please modify RuleIdCollection.scala if you are adding a new rule." + ], + "sqlState" : "22023" + }, "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION" : { "message" : [ "The correlated scalar subquery '<sqlExpr>' is neither present in GROUP BY, nor in an aggregate function. Add it to GROUP BY using ordinal position or wrap it in `first()` (or `first_value`) if you don't care which value you get." ], @@ -5489,11 +5495,6 @@ "." ] }, - "_LEGACY_ERROR_TEMP_2175" : { - "message" : [ - "Rule id not found for <ruleName>. Please modify RuleIdCollection.scala if you are adding a new rule." - ] - }, "_LEGACY_ERROR_TEMP_2176" : { "message" : [ "Cannot create array with elements of data due to exceeding the limit elements for ArrayData. " ] }, diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 6ea16d7ef31b3..5609d60f97419 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -1586,6 +1586,12 @@ The function `<routineName>` cannot be found. Verify the spelling and correctnes If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog. To tolerate the error on drop use DROP FUNCTION IF EXISTS. +### RULE_ID_NOT_FOUND + +[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception) + +Not found an id for the rule name "`<ruleName>`". Please modify RuleIdCollection.scala if you are adding a new rule.
+ ### SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION SQLSTATE: none assigned diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 3622ffebb74d9..45b5d6b6692cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1584,9 +1584,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def ruleIdNotFoundForRuleError(ruleName: String): Throwable = { new SparkException( - errorClass = "_LEGACY_ERROR_TEMP_2175", - messageParameters = Map( - "ruleName" -> ruleName), + errorClass = "RULE_ID_NOT_FOUND", + messageParameters = Map("ruleName" -> ruleName), cause = null) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index e70d04b7b5a6f..ae1c0a86a14c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean +import org.apache.spark.sql.catalyst.rules.RuleIdCollection import org.apache.spark.sql.catalyst.util.BadRecordException import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions} import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider @@ -499,6 +500,16 @@ class QueryExecutionErrorsSuite } } + test("SPARK-42330: rule id not found") { + checkError( + exception = intercept[SparkException] { + RuleIdCollection.getRuleId("incorrect") + }, + errorClass = "RULE_ID_NOT_FOUND", + parameters = Map("ruleName" -> "incorrect") + ) + } + test("CANNOT_RESTORE_PERMISSIONS_FOR_PATH: can't set permission") { withTable("t") { withSQLConf( From ed036a9d0aab2d75b5c0db5caebfc158ce22ec15 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 Aug 2023 14:18:16 -0700 Subject: [PATCH 09/68] [SPARK-44658][CORE] `ShuffleStatus.getMapStatus` should return `None` instead of `Some(null)` ### What changes were proposed in this pull request? This PR is for `master` and `branch-3.5` and aims to fix a regression due to SPARK-43043 which landed at Apache Spark 3.4.1 and reverted via SPARK-44630. This PR makes `ShuffleStatus.getMapStatus` return `None` instead of `Some(null)`. ### Why are the changes needed? `None` is better because `Some(null)` is unsafe because it causes NPE in some cases. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with the newly added test case. Closes #42323 from dongjoon-hyun/SPARK-44658. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/MapOutputTracker.scala | 5 ++++- .../scala/org/apache/spark/MapOutputTrackerSuite.scala | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 47ac3df4cc62c..3495536a3508f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -171,7 +171,10 @@ private class ShuffleStatus( * Get the map output that corresponding to a given mapId. */ def getMapStatus(mapId: Long): Option[MapStatus] = withReadLock { - mapIdToMapIndex.get(mapId).map(mapStatuses(_)) + mapIdToMapIndex.get(mapId).map(mapStatuses(_)) match { + case Some(null) => None + case m => m + } } /** diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7ac3d0092c8ce..7ee36137e2715 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -1083,4 +1083,13 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { rpcEnv.shutdown() } } + + test("SPARK-44658: ShuffleStatus.getMapStatus should return None") { + val bmID = BlockManagerId("a", "hostA", 1000) + val mapStatus = MapStatus(bmID, Array(1000L, 10000L), mapTaskId = 0) + val shuffleStatus = new ShuffleStatus(1000) + shuffleStatus.addMapOutput(mapIndex = 1, mapStatus) + shuffleStatus.removeMapOutput(mapIndex = 1, bmID) + assert(shuffleStatus.getMapStatus(0).isEmpty) + } } From 9fbf0b4853c6209675daa0731f8b33a83b2f5cef Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 Aug 2023 14:40:59 -0700 Subject: [PATCH 10/68] [SPARK-44661][CORE][TESTS] `getMapOutputLocation` should not throw NPE ### What changes were proposed in this pull request? This PR aims to add a test coverage for Apache Spark 4.0/3.5/3.4. This PR depends on SPARK-44658 (#42323) but is created separately because this aims to land `branch-3.4` too. ### Why are the changes needed? To prevent a future regression. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. Closes #42326 from dongjoon-hyun/SPARK-44661. 
Lead-authored-by: Dongjoon Hyun Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../apache/spark/MapOutputTrackerSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7ee36137e2715..450ff01921a83 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -1092,4 +1092,21 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { shuffleStatus.removeMapOutput(mapIndex = 1, bmID) assert(shuffleStatus.getMapStatus(0).isEmpty) } + + test("SPARK-44661: getMapOutputLocation should not throw NPE") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + try { + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(0, 1, 1) + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000), + Array(2L), 0)) + tracker.removeOutputsOnHost("hostA") + assert(tracker.getMapOutputLocation(0, 0) == None) + } finally { + tracker.stop() + rpcEnv.shutdown() + } + } } From 16b031eb144f6ba1c1103be5dcf00d6209adaa85 Mon Sep 17 00:00:00 2001 From: Amanda Liu Date: Fri, 4 Aug 2023 08:42:39 +0900 Subject: [PATCH 11/68] [SPARK-44652] Raise error when only one df is None ### What changes were proposed in this pull request? Adds a "raise PySparkAssertionError" for the case when one of `actual` or `expected` is None, instead of just returning False. ### Why are the changes needed? The PR ensures that an error is thrown in the assertion for the edge case when one of `actual` or `expected` is None ### Does this PR introduce _any_ user-facing change? Yes, the PR affects the user-facing API `assertDataFrameEqual` ### How was this patch tested? Added tests to `python/pyspark/sql/tests/test_utils.py` Closes #42314 from asl3/raise-none-error. 
Authored-by: Amanda Liu Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_utils.py | 82 +++++++++++++++++++++++--- python/pyspark/testing/utils.py | 32 ++++++++-- 2 files changed, 99 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index 76d397e3adeb8..93895465de7f7 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -41,6 +41,7 @@ BooleanType, ) from pyspark.sql.dataframe import DataFrame +import pyspark.pandas as ps import difflib from typing import List, Union @@ -672,9 +673,79 @@ def test_assert_equal_nulldf(self): assertDataFrameEqual(df1, df2, checkRowOrder=False) assertDataFrameEqual(df1, df2, checkRowOrder=True) - def test_assert_equal_exact_pandas_df(self): - import pyspark.pandas as ps + def test_assert_unequal_null_actual(self): + df1 = None + df2 = self.spark.createDataFrame( + data=[ + ("1", 1000), + ("2", 3000), + ], + schema=["id", "amount"], + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2) + + self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "actual", + "actual_type": None, + }, + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "actual", + "actual_type": None, + }, + ) + + def test_assert_unequal_null_expected(self): + df1 = self.spark.createDataFrame( + data=[ + ("1", 1000), + ("2", 3000), + ], + schema=["id", "amount"], + ) + df2 = None + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2) + self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "expected", + "actual_type": None, + }, + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "expected", + "actual_type": None, + }, + ) + + def test_assert_equal_exact_pandas_df(self): df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) df2 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) @@ -682,16 +753,12 @@ def test_assert_equal_exact_pandas_df(self): assertDataFrameEqual(df1, df2, checkRowOrder=True) def test_assert_equal_exact_pandas_df(self): - import pyspark.pandas as ps - df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) df2 = ps.DataFrame(data=[30, 20, 10], columns=["Numbers"]) assertDataFrameEqual(df1, df2) def test_assert_equal_approx_pandas_df(self): - import pyspark.pandas as ps - df1 = ps.DataFrame(data=[10.0001, 20.32, 30.1], columns=["Numbers"]) df2 = ps.DataFrame(data=[10.0, 20.32, 30.1], columns=["Numbers"]) @@ -699,7 +766,6 @@ def test_assert_equal_approx_pandas_df(self): assertDataFrameEqual(df1, df2, checkRowOrder=True) def test_assert_error_pandas_pyspark_df(self): - import pyspark.pandas as ps import pandas as pd df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) @@ -742,8 +808,6 @@ def 
test_assert_error_pandas_pyspark_df(self): ) def test_assert_error_non_pyspark_df(self): - import pyspark.pandas as ps - dict1 = {"a": 1, "b": 2} dict2 = {"a": 1, "b": 2} diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 5461577fad1de..8e02803efe5cb 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -292,6 +292,7 @@ def assertSchemaEqual(actual: StructType, expected: StructType): >>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)]) >>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)]) >>> assertSchemaEqual(s1, s2) # pass, schemas are identical + >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "number"]) >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], schema=["id", "amount"]) >>> assertSchemaEqual(df1.schema, df2.schema) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -414,16 +415,20 @@ def assertDataFrameEqual( >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"]) >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"]) >>> assertDataFrameEqual(df1, df2) # pass, DataFrames are identical + >>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"]) >>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"]) >>> assertDataFrameEqual(df1, df2, rtol=1e-1) # pass, DataFrames are approx equal by rtol + >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "amount"]) >>> list_of_rows = [Row(1, 1000), Row(2, 3000)] >>> assertDataFrameEqual(df1, list_of_rows) # pass, actual and expected data are equal + >>> import pyspark.pandas as ps >>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) >>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) >>> assertDataFrameEqual(df1, df2) # pass, pandas-on-Spark DataFrames are equal + >>> df1 = spark.createDataFrame( ... data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"]) >>> df2 = spark.createDataFrame( @@ -436,20 +441,35 @@ def assertDataFrameEqual( ! Row(id='1', amount=1000.0) Row(id='2', amount=3000.0) ! Row(id='3', amount=2000.0) - *** expected *** ! Row(id='1', amount=1001.0) Row(id='2', amount=3000.0) ! 
Row(id='3', amount=2003.0) """ - if actual is None and expected is None: - return True - elif actual is None or expected is None: - return False - import pyspark.pandas as ps from pyspark.testing.pandasutils import assertPandasOnSparkEqual + if actual is None and expected is None: + return True + elif actual is None: + raise PySparkAssertionError( + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "actual", + "actual_type": None, + }, + ) + elif expected is None: + raise PySparkAssertionError( + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "expected", + "actual_type": None, + }, + ) + try: # If Spark Connect dependencies are available, allow Spark Connect DataFrame from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame From d2d43b888aebbb5d4099faec26b076ef390890ce Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 4 Aug 2023 08:45:44 +0900 Subject: [PATCH 12/68] [SPARK-44642][CONNECT] ReleaseExecute in ExecutePlanResponseReattachableIterator after it gets error from server ### What changes were proposed in this pull request? Client: When server returns error on the response stream via onError, the ExecutePlanResponseReattachableIterator will not see the stream finish with a ResultsComplete. Instead, a StatusRuntimeException will be thrown from next() or hasNext(). Handle catching that exception, telling the server to ReleaseExecute when we receive it, and rethrow it to the user. Server: We also have to tweak the behaviour of ReleaseAll to also interrupt the query. The previous behaviour that in case of a running query one has to first send an interrupt, and then release was done to prevent race conditions of an interrupt coming after ResultComplete. Now, this has been resolved with proper synchronization at the final moments of execution in ExecuteThreadRunner, and as we want the release to be async, having one ReleaseExecutel vs. needing a combination of Interrupt+ReleaseExecute simplifies things. ### Why are the changes needed? If ReleaseExecute is not called by the client to acknowledge that the error was received, the execution will keep dangling on the server until cleaned up by timeout. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Automated tests will come with https://issues.apache.org/jira/browse/SPARK-44625. Closes #42304 from juliuszsompolski/SPARK-44642. 
Authored-by: Juliusz Sompolski Signed-off-by: Hyukjin Kwon --- ...cutePlanResponseReattachableIterator.scala | 120 +++++++++++------- .../main/protobuf/spark/connect/base.proto | 4 +- .../execution/ExecuteThreadRunner.scala | 31 +++-- .../sql/connect/service/ExecuteHolder.scala | 6 +- .../planner/SparkConnectServiceSuite.scala | 25 +++- python/pyspark/sql/connect/proto/base_pb2.pyi | 4 +- 6 files changed, 116 insertions(+), 74 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 008b3c3dd5c71..fc07deaa081f8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -102,28 +102,33 @@ class ExecutePlanResponseReattachableIterator( throw new java.util.NoSuchElementException() } - // Get next response, possibly triggering reattach in case of stream error. - var firstTry = true - val ret = retry { - if (firstTry) { - // on first try, we use the existing iterator. - firstTry = false - } else { - // on retry, the iterator is borked, so we need a new one - iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + try { + // Get next response, possibly triggering reattach in case of stream error. + var firstTry = true + val ret = retry { + if (firstTry) { + // on first try, we use the existing iterator. + firstTry = false + } else { + // on retry, the iterator is borked, so we need a new one + iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + } + iterator.next() } - iterator.next() - } - // Record last returned response, to know where to restart in case of reattach. - lastReturnedResponseId = Some(ret.getResponseId) - if (ret.hasResultComplete) { - resultComplete = true - releaseExecute(None) // release all - } else { - releaseExecute(lastReturnedResponseId) // release until this response + // Record last returned response, to know where to restart in case of reattach. + lastReturnedResponseId = Some(ret.getResponseId) + if (ret.hasResultComplete) { + releaseAll() + } else { + releaseUntil(lastReturnedResponseId.get) + } + ret + } catch { + case NonFatal(ex) => + releaseAll() // ReleaseExecute on server after error. + throw ex } - ret } override def hasNext(): Boolean = synchronized { @@ -132,47 +137,64 @@ class ExecutePlanResponseReattachableIterator( return false } var firstTry = true - retry { - if (firstTry) { - // on first try, we use the existing iterator. - firstTry = false - } else { - // on retry, the iterator is borked, so we need a new one - iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) - } - var hasNext = iterator.hasNext() - // Graceful reattach: - // If iterator ended, but there was no ResultComplete, it means that there is more, - // and we need to reattach. - if (!hasNext && !resultComplete) { - do { + try { + retry { + if (firstTry) { + // on first try, we use the existing iterator. + firstTry = false + } else { + // on retry, the iterator is borked, so we need a new one iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) - assert(!resultComplete) // shouldn't change... 
- hasNext = iterator.hasNext() - // It's possible that the new iterator will be empty, so we need to loop to get another. - // Eventually, there will be a non empty iterator, because there's always a ResultComplete - // at the end of the stream. - } while (!hasNext) + } + var hasNext = iterator.hasNext() + // Graceful reattach: + // If iterator ended, but there was no ResultComplete, it means that there is more, + // and we need to reattach. + if (!hasNext && !resultComplete) { + do { + iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + assert(!resultComplete) // shouldn't change... + hasNext = iterator.hasNext() + // It's possible that the new iterator will be empty, so we need to loop to get another. + // Eventually, there will be a non empty iterator, because there is always a + // ResultComplete inserted by the server at the end of the stream. + } while (!hasNext) + } + hasNext } - hasNext + } catch { + case NonFatal(ex) => + releaseAll() // ReleaseExecute on server after error. + throw ex } } /** - * Inform the server to release the execution. + * Inform the server to release the buffered execution results until and including given result. * * This will send an asynchronous RPC which will not block this iterator, the iterator can * continue to be consumed. + */ + private def releaseUntil(untilResponseId: String): Unit = { + if (!resultComplete) { + val request = createReleaseExecuteRequest(Some(untilResponseId)) + rawAsyncStub.releaseExecute(request, createRetryingReleaseExecuteResponseObserer(request)) + } + } + + /** + * Inform the server to release the execution, either because all results were consumed, or the + * execution finished with error and the error was received. * - * Release with untilResponseId informs the server that the iterator has been consumed until and - * including response with that responseId, and these responses can be freed. - * - * Release with None means that the responses have been completely consumed and informs the - * server that the completed execution can be completely freed. + * This will send an asynchronous RPC which will not block this. The client continues executing, + * and if the release fails, server is equipped to deal with abandoned executions. */ - private def releaseExecute(untilResponseId: Option[String]): Unit = { - val request = createReleaseExecuteRequest(untilResponseId) - rawAsyncStub.releaseExecute(request, createRetryingReleaseExecuteResponseObserer(request)) + private def releaseAll(): Unit = { + if (!resultComplete) { + val request = createReleaseExecuteRequest(None) + rawAsyncStub.releaseExecute(request, createRetryingReleaseExecuteResponseObserer(request)) + resultComplete = true + } } /** diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 151e828b3e903..79dbadba5bb07 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -750,9 +750,7 @@ message ReleaseExecuteRequest { optional string client_type = 4; // Release and close operation completely. - // Note: This should be called when the server side operation is finished, and ExecutePlan or - // ReattachExecute are finished processing the result stream, or inside onComplete / onError. - // This will not interrupt a running execution, but block until it's finished. 
+ // This will also interrupt the query if it is running execution, and wait for it to be torn down. message ReleaseAll {} // Release all responses from the operation response stream up to and including diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 662288177dc69..930ccae5d4c76 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -46,6 +46,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends private var completed: Boolean = false + private val lock = new Object + /** Launches the execution in a background thread, returns immediately. */ def start(): Unit = { executionThread.start() @@ -62,7 +64,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends * true if it was not interrupted before, false if it was already interrupted or completed. */ def interrupt(): Boolean = { - synchronized { + lock.synchronized { if (!interrupted && !completed) { // checking completed prevents sending interrupt onError after onCompleted interrupted = true @@ -119,7 +121,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends // Inner executeInternal is wrapped by execute() for error handling. private def executeInternal() = { // synchronized - check if already got interrupted while starting. - synchronized { + lock.synchronized { if (interrupted) { throw new InterruptedException() } @@ -160,14 +162,23 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends s"${executeHolder.request.getPlan.getOpTypeCase} not supported.") } - if (executeHolder.reattachable) { - // Reattachable execution sends a ResultComplete at the end of the stream - // to signal that there isn't more coming. - executeHolder.responseObserver.onNext(createResultComplete()) - } - synchronized { - // Prevent interrupt after onCompleted, and throwing error to an already closed stream. - completed = true + lock.synchronized { + // Synchronized before sending ResultComplete, and up until completing the result stream + // to prevent a situation in which a client of reattachable execution receives + // ResultComplete, and proceeds to send ReleaseExecute, and that triggers an interrupt + // before it finishes. + + if (interrupted) { + // check if it got interrupted at the very last moment + throw new InterruptedException() + } + completed = true // no longer interruptible + + if (executeHolder.reattachable) { + // Reattachable execution sends a ResultComplete at the end of the stream + // to signal that there isn't more coming. 
+ executeHolder.responseObserver.onNext(createResultComplete()) + } executeHolder.responseObserver.onCompleted() } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index a49c0a8bacf98..4eb90f9f1639a 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -156,11 +156,11 @@ private[connect] class ExecuteHolder( } /** - * Close the execution and remove it from the session. Note: It blocks joining the - * ExecuteThreadRunner thread, so it assumes that it's called when the execution is ending or - * ended. If it is desired to kill the execution, interrupt() should be called first. + * Close the execution and remove it from the session. Note: it first interrupts the runner if + * it's still running, and it waits for it to finish. */ def close(): Unit = { + runner.interrupt() runner.join() eventsManager.postClosed() sessionHolder.removeExecuteHolder(operationId) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index c29a9b9b62958..e833d12c4f595 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -493,24 +493,37 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .setSessionId(sessionId) .build() - // The observer is executed inside this thread. So - // we can perform the checks inside the observer. + // Even though the observer is executed inside this thread, this thread is also executing + // the SparkConnectService. If we throw an exception inside it, it will be caught by + // the ErrorUtils.handleError wrapping instance.executePlan and turned into an onError + // call with StatusRuntimeException, which will be eaten here. + var failures: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer[String]() instance.executePlan( request, new StreamObserver[proto.ExecutePlanResponse] { override def onNext(v: proto.ExecutePlanResponse): Unit = { - fail("this should not receive responses") + // The query receives some pre-execution responses such as schema, but should + // never proceed to execution and get query results. + if (v.hasArrowBatch) { + failures += s"this should not receive query results but got $v" + } } override def onError(throwable: Throwable): Unit = { - assert(throwable.isInstanceOf[StatusRuntimeException]) - verifyEvents.onError(throwable) + try { + assert(throwable.isInstanceOf[StatusRuntimeException]) + verifyEvents.onError(throwable) + } catch { + case t: Throwable => + failures += s"assertion $t validating processing onError($throwable)." 
+ } } override def onCompleted(): Unit = { - fail("this should not complete") + failures += "this should not complete" } }) + assert(failures.isEmpty, s"this should have no failures but got $failures") verifyEvents.onCompleted() } } diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index e870221594c13..a886ecbd61842 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -2554,9 +2554,7 @@ class ReleaseExecuteRequest(google.protobuf.message.Message): class ReleaseAll(google.protobuf.message.Message): """Release and close operation completely. - Note: This should be called when the server side operation is finished, and ExecutePlan or - ReattachExecute are finished processing the result stream, or inside onComplete / onError. - This will not interrupt a running execution, but block until it's finished. + This will also interrupt the query if it is running execution, and wait for it to be torn down. """ DESCRIPTOR: google.protobuf.descriptor.Descriptor From 52a9002fa2383bd9b26c77e62e0c6bcd46f8944b Mon Sep 17 00:00:00 2001 From: Sergii Druzkin <65374769+sdruzkin@users.noreply.github.com> Date: Thu, 3 Aug 2023 18:52:44 -0500 Subject: [PATCH 13/68] [MINOR][DOC] Fix a typo in ResolveReferencesInUpdate scaladoc ### What changes were proposed in this pull request? Fixed a typo in the ResolveReferencesInUpdate documentation. ### Why are the changes needed? ### Does this PR introduce any user-facing change? No ### How was this patch tested? CI Closes #42322 from sdruzkin/master. Authored-by: Sergii Druzkin <65374769+sdruzkin@users.noreply.github.com> Signed-off-by: Sean Owen --- .../spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala index cebc1e25f9213..ead323ce9857b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors /** * A virtual rule to resolve [[UnresolvedAttribute]] in [[UpdateTable]]. It's only used by the real * rule `ResolveReferences`. The column resolution order for [[UpdateTable]] is: - * 1. Resolves the column to `AttributeReference`` with the output of the child plan. This + * 1. Resolves the column to `AttributeReference` with the output of the child plan. This * includes metadata columns as well. * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g. * `SELECT col, current_date FROM t`. From 5c36c58047724885864cb781f17038a6b9c94513 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 4 Aug 2023 09:14:05 +0900 Subject: [PATCH 14/68] [SPARK-44433][PYTHON][CONNECT][SS][FOLLOWUP] Terminate listener process with `removeListener` and improvements ### What changes were proposed in this pull request? This is a followup to #42116. It addresses the following issues: 1. When `removeListener` is called upon one listener, before the python process is left running, now it also get stopped. 2. When multiple `removeListener` is called on the same listener, in non-connect mode, subsequent calls will be noop. 
But before this PR, in connect it actually throws an error, which doesn't align with existing behavior, this PR addresses it. 3. Set the socket timeout to be None (\infty) for `foreachBatch_worker` and `listener_worker`, because there could be a long time between each microbatch. If not setting this, the socket will timeout and won't be able to process new data. ``` scala> Streaming query listener worker is starting with url sc://localhost:15002/;user_id=wei.liu and sessionId 886191f0-2b64-4c44-b067-de511f04b42d. Traceback (most recent call last): File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main return _run_code(code, main_globals, None, File "/usr/lib/python3.9/runpy.py", line 87, in _run_code exec(code, run_globals) File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 95, in File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 82, in main File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/serializers.py", line 557, in loads File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/serializers.py", line 594, in read_int File "/usr/lib/python3.9/socket.py", line 704, in readinto return self._sock.recv_into(b) socket.timeout: timed out ``` ### Why are the changes needed? Necessary improvements ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manual test + unit test Closes #42283 from WweiL/SPARK-44433-listener-process-termination. Authored-by: Wei Liu Signed-off-by: Hyukjin Kwon --- .../streaming/StreamingQueryListener.scala | 28 --------------- .../connect/planner/SparkConnectPlanner.scala | 12 ++++--- .../planner/StreamingForeachBatchHelper.scala | 10 +++--- .../StreamingQueryListenerHelper.scala | 21 ++++++----- .../sql/connect/service/SessionHolder.scala | 19 +++++----- .../api/python/StreamingPythonRunner.scala | 36 ++++++++++++++----- .../streaming/worker/foreachBatch_worker.py | 4 ++- .../streaming/worker/listener_worker.py | 4 ++- .../connect/streaming/test_parity_listener.py | 7 ++++ 9 files changed, 77 insertions(+), 64 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index e2f3be02ad3ae..404bd1b078ba4 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -75,34 +75,6 @@ abstract class StreamingQueryListener extends Serializable { def onQueryTerminated(event: QueryTerminatedEvent): Unit } -/** - * Py4J allows a pure interface so this proxy is required. 
- */ -private[spark] trait PythonStreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit - - def onQueryProgress(event: QueryProgressEvent): Unit - - def onQueryIdle(event: QueryIdleEvent): Unit - - def onQueryTerminated(event: QueryTerminatedEvent): Unit -} - -private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener) - extends StreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event) - - def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event) - - override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event) - - def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event) -} - /** * Companion object of [[StreamingQueryListener]] that defines the listener events. * @since 3.5.0 diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index f4b33ae961a2f..7136476b515f9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3097,10 +3097,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER => val listenerId = command.getRemoveListener.getId - val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId) - session.streams.removeListener(listener) - sessionHolder.removeCachedListener(listenerId) - respBuilder.setRemoveListener(true) + sessionHolder.getListener(listenerId) match { + case Some(listener) => + session.streams.removeListener(listener) + sessionHolder.removeCachedListener(listenerId) + respBuilder.setRemoveListener(true) + case None => + respBuilder.setRemoveListener(false) + } case StreamingQueryManagerCommand.CommandCase.LIST_LISTENERS => respBuilder.getListListenersBuilder diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 998faf327d03a..4f1037b86c9f2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -87,11 +87,13 @@ object StreamingForeachBatchHelper extends Logging { val port = SparkConnectService.localPort val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" - val runner = StreamingPythonRunner(pythonFn, connectUrl) + val runner = StreamingPythonRunner( + pythonFn, + connectUrl, + sessionHolder.sessionId, + "pyspark.sql.connect.streaming.worker.foreachBatch_worker") val (dataOut, dataIn) = - runner.init( - sessionHolder.sessionId, - "pyspark.sql.connect.streaming.worker.foreachBatch_worker") + runner.init() val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala index d915bc9349609..9b2a931ec4acb 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.streaming.StreamingQueryListener /** * A helper class for handling StreamingQueryListener related functionality in Spark Connect. Each * instance of this class starts a python process, inside which has the python handling logic. - * When new a event is received, it is serialized to json, and passed to the python process. + * When a new event is received, it is serialized to json, and passed to the python process. */ class PythonStreamingQueryListener( listener: SimplePythonFunction, @@ -32,12 +32,15 @@ class PythonStreamingQueryListener( pythonExec: String) extends StreamingQueryListener { - val port = SparkConnectService.localPort - val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" - val runner = StreamingPythonRunner(listener, connectUrl) + private val port = SparkConnectService.localPort + private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" + private val runner = StreamingPythonRunner( + listener, + connectUrl, + sessionHolder.sessionId, + "pyspark.sql.connect.streaming.worker.listener_worker") - val (dataOut, _) = - runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker") + val (dataOut, _) = runner.init() override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { PythonRDD.writeUTF(event.json, dataOut) @@ -63,7 +66,7 @@ class PythonStreamingQueryListener( dataOut.flush() } - // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes. - // Similar to foreachBatch when we need to exit the process when the query ends. - // In listener semantics, we need to exit the process when removeListener is called. + private[spark] def stopListenerProcess(): Unit = { + runner.stop() + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 310bb9208c21d..29134f0dc0ded 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.InvalidPlanInput +import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.{SystemClock} import org.apache.spark.util.Utils @@ -220,20 +221,22 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } /** - * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, throw - * [[InvalidPlanInput]]. + * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, return + * None. 
*/ - private[connect] def getListenerOrThrow(id: String): StreamingQueryListener = { + private[connect] def getListener(id: String): Option[StreamingQueryListener] = { Option(listenerCache.get(id)) - .getOrElse { - throw InvalidPlanInput(s"No listener with id $id is found in the session $sessionId") - } } /** - * Removes corresponding StreamingQueryListener by ID. + * Removes corresponding StreamingQueryListener by ID. Terminates the python process if it's a + * Spark Connect PythonStreamingQueryListener. */ - private[connect] def removeCachedListener(id: String): StreamingQueryListener = { + private[connect] def removeCachedListener(id: String): Unit = { + listenerCache.get(id) match { + case pyListener: PythonStreamingQueryListener => pyListener.stopListenerProcess() + case _ => // do nothing + } listenerCache.remove(id) } diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index d4fd9485675fa..f14289f984a2f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -29,27 +29,36 @@ import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTH private[spark] object StreamingPythonRunner { - def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = { - new StreamingPythonRunner(func, connectUrl) + def apply( + func: PythonFunction, + connectUrl: String, + sessionId: String, + workerModule: String + ): StreamingPythonRunner = { + new StreamingPythonRunner(func, connectUrl, sessionId, workerModule) } } -private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String) - extends Logging { +private[spark] class StreamingPythonRunner( + func: PythonFunction, + connectUrl: String, + sessionId: String, + workerModule: String) extends Logging { private val conf = SparkEnv.get.conf protected val bufferSize: Int = conf.get(BUFFER_SIZE) protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val envVars: java.util.Map[String, String] = func.envVars private val pythonExec: String = func.pythonExec + private var pythonWorker: Option[Socket] = None protected val pythonVer: String = func.pythonVer /** * Initializes the Python worker for streaming functions. Sets up Spark Connect session * to be used with the functions. 
*/ - def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = { - logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec") + def init(): (DataOutputStream, DataInputStream) = { + logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec") val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") @@ -60,9 +69,9 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str conf.set(PYTHON_USE_DAEMON, false) envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl) - val pythonWorkerFactory = - new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap) - val (worker: Socket, _) = pythonWorkerFactory.createSimpleWorker() + val (worker, _) = env.createPythonWorker( + pythonExec, workerModule, envVars.asScala.toMap) + pythonWorker = Some(worker) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) @@ -85,4 +94,13 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str (dataOut, dataIn) } + + /** + * Stops the Python worker. + */ + def stop(): Unit = { + pythonWorker.foreach { worker => + SparkEnv.get.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) + } + } } diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py index 054788539f293..48a9848de4009 100644 --- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py @@ -76,7 +76,9 @@ def process(df_id, batch_id): # type: ignore[no-untyped-def] # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) + # There could be a long time between each micro batch. + sock.settimeout(None) write_int(os.getpid(), sock_file) sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py index 8eb310461b6f6..7aef911426de7 100644 --- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -89,7 +89,9 @@ def process(listener_event_str, listener_event_type): # type: ignore[no-untyped # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) + # There could be a long time between each listener event. 
+ sock.settimeout(None) write_int(os.getpid(), sock_file) sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 547462d4da6d5..4bf58bf7807b3 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -60,6 +60,10 @@ def test_listener_events(self): try: self.spark.streams.addListener(test_listener) + # This ensures the read socket on the server won't crash (i.e. because of timeout) + # when there hasn't been a new event for a long time + time.sleep(30) + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() q = df.writeStream.format("noop").queryName("test").start() @@ -76,6 +80,9 @@ def test_listener_events(self): finally: self.spark.streams.removeListener(test_listener) + # Remove again to verify this won't throw any error + self.spark.streams.removeListener(test_listener) + if __name__ == "__main__": import unittest From 380c0f2033fb83b5e4f13693d2576d72c5cc01f2 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Fri, 4 Aug 2023 10:22:46 +0900 Subject: [PATCH 15/68] [SPARK-44640][PYTHON] Improve error messages for Python UDTF returning non Iterable ### What changes were proposed in this pull request? This PR improves error messages when the result of a Python UDTF is not an Iterable. It also improves the error messages when a UDTF encounters an exception when executing `eval`. ### Why are the changes needed? To make Python UDTFs more user-friendly. ### Does this PR introduce _any_ user-facing change? Yes. For example this UDTF: ``` udtf(returnType="x: int") class TestUDTF: def eval(self, a): return a ``` Before this PR, it fails with this error for regular UDTFs: ``` return tuple(map(verify_and_convert_result, res)) TypeError: 'int' object is not iterable ``` And this error for arrow-optimized UDTFs: ``` raise ValueError("DataFrame constructor not properly called!") ValueError: DataFrame constructor not properly called! ``` After this PR, the error message will be: `pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_RETURN_NOT_ITERABLE] The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got 'int'. Please make sure that the UDTF returns one of these types.` ### How was this patch tested? New UTs. Closes #42302 from allisonwang-db/spark-44640-udtf-non-iterable. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- python/pyspark/errors/error_classes.py | 5 +++ python/pyspark/sql/tests/test_udtf.py | 42 +++++++++++++++++++- python/pyspark/sql/udtf.py | 40 +++++++++++++++---- python/pyspark/worker.py | 53 ++++++++++++++------------ 4 files changed, 105 insertions(+), 35 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index d6f093246dacd..84448f1507dd8 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -738,6 +738,11 @@ "User defined table function encountered an error in the '' method: " ] }, + "UDTF_RETURN_NOT_ITERABLE" : { + "message" : [ + "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got ''. Please make sure that the UDTF returns one of these types." + ] + }, "UDTF_RETURN_SCHEMA_MISMATCH" : { "message" : [ "The number of columns in the result does not match the specified schema. 
Expected column count: , Actual column count: . Please make sure the values returned by the function have the same number of columns as specified in the output schema." diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 65184549573dc..26da83980e160 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -180,6 +180,15 @@ def eval(self, a: int): with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): func(lit(1)).collect() + def test_udtf_with_invalid_return_value(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self, a): + return a + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF(lit(1)).collect() + def test_udtf_eval_with_no_return(self): @udtf(returnType="a: int") class TestUDTF: @@ -375,6 +384,35 @@ def terminate(self): ], ) + def test_init_with_exception(self): + @udtf(returnType="x: int") + class TestUDTF: + def __init__(self): + raise Exception("error") + + def eval(self): + yield 1, + + with self.assertRaisesRegex( + PythonException, + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the '__init__' method: error", + ): + TestUDTF().show() + + def test_eval_with_exception(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + raise Exception("error") + + with self.assertRaisesRegex( + PythonException, + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the 'eval' method: error", + ): + TestUDTF().show() + def test_terminate_with_exceptions(self): @udtf(returnType="a: int, b: int") class TestUDTF: @@ -386,8 +424,8 @@ def terminate(self): with self.assertRaisesRegex( PythonException, - "User defined table function encountered an error in the 'terminate' " - "method: terminate error", + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the 'terminate' method: terminate error", ): TestUDTF(lit(1)).collect() diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index d14a263f839c9..74a9084c6cd55 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -18,14 +18,15 @@ User-defined table function related classes and functions """ from dataclasses import dataclass +from functools import wraps import inspect import sys import warnings -from typing import Any, Iterator, Type, TYPE_CHECKING, Optional, Union +from typing import Any, Iterable, Iterator, Type, TYPE_CHECKING, Optional, Union, Callable from py4j.java_gateway import JavaObject -from pyspark.errors import PySparkAttributeError, PySparkTypeError +from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError from pyspark.rdd import PythonEvalType from pyspark.sql.column import _to_java_column, _to_seq from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -143,6 +144,20 @@ def _vectorize_udtf(cls: Type) -> Type: """Vectorize a Python UDTF handler class.""" import pandas as pd + # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. 
+ def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def evaluate(*a: Any) -> Any: + try: + return f(*a) + except Exception as e: + raise PySparkRuntimeError( + error_class="UDTF_EXEC_ERROR", + message_parameters={"method_name": f.__name__, "error": str(e)}, + ) + + return evaluate + class VectorizedUDTF: def __init__(self) -> None: self.func = cls() @@ -157,17 +172,26 @@ def analyze(*args: AnalyzeArgument) -> AnalyzeResult: def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]: if len(args) == 0: - yield pd.DataFrame(self.func.eval()) + yield pd.DataFrame(wrap_func(self.func.eval)()) else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. row_tuples = zip(*args) for row in row_tuples: - yield pd.DataFrame(self.func.eval(*row)) - - def terminate(self) -> Iterator[pd.DataFrame]: - if hasattr(self.func, "terminate"): - yield pd.DataFrame(self.func.terminate()) + res = wrap_func(self.func.eval)(*row) + if res is not None and not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={ + "type": type(res).__name__, + }, + ) + yield pd.DataFrame(res) + + if hasattr(cls, "terminate"): + + def terminate(self) -> Iterator[pd.DataFrame]: + yield pd.DataFrame(wrap_func(self.func.terminate)()) vectorized_udtf = VectorizedUDTF vectorized_udtf.__name__ = cls.__name__ diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 20e856c9addc3..3acfa58b6fb8b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,7 +23,7 @@ import time from inspect import currentframe, getframeinfo, getfullargspec import json -from typing import Iterator +from typing import Iterable, Iterator # 'resource' is a Unix specific module. has_resource_module = True @@ -591,6 +591,7 @@ def read_udtf(pickleSer, infile, eval_type): def wrap_arrow_udtf(f, return_type): arrow_return_type = to_arrow_type(return_type) + return_type_size = len(return_type) def verify_result(result): import pandas as pd @@ -599,7 +600,7 @@ def verify_result(result): raise PySparkTypeError( error_class="INVALID_ARROW_UDTF_RETURN_TYPE", message_parameters={ - "type_name": type(result).__name_, + "type_name": type(result).__name__, "value": str(result), }, ) @@ -609,11 +610,11 @@ def verify_result(result): # result dataframe may contain an empty row. For example, when a UDTF is # defined as follows: def eval(self): yield tuple(). 
if len(result) > 0 or len(result.columns) > 0: - if len(result.columns) != len(return_type): + if len(result.columns) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ - "expected": str(len(return_type)), + "expected": str(return_type_size), "actual": str(len(result.columns)), }, ) @@ -641,13 +642,7 @@ def mapper(_, it): yield from eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: - try: - yield from terminate() - except BaseException as e: - raise PySparkRuntimeError( - error_class="UDTF_EXEC_ERROR", - message_parameters={"method_name": "terminate", "error": str(e)}, - ) + yield from terminate() return mapper, None, ser, ser @@ -656,15 +651,16 @@ def mapper(_, it): def wrap_udtf(f, return_type): assert return_type.needConversion() toInternal = return_type.toInternal + return_type_size = len(return_type) def verify_and_convert_result(result): # TODO(SPARK-44005): support returning non-tuple values if result is not None and hasattr(result, "__len__"): - if len(result) != len(return_type): + if len(result) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ - "expected": str(len(return_type)), + "expected": str(return_type_size), "actual": str(len(result)), }, ) @@ -672,16 +668,29 @@ def verify_and_convert_result(result): # Evaluate the function and return a tuple back to the executor. def evaluate(*a) -> tuple: - res = f(*a) + try: + res = f(*a) + except Exception as e: + raise PySparkRuntimeError( + error_class="UDTF_EXEC_ERROR", + message_parameters={"method_name": f.__name__, "error": str(e)}, + ) + if res is None: # If the function returns None or does not have an explicit return statement, # an empty tuple is returned to the executor. # This is because directly constructing tuple(None) results in an exception. return tuple() - else: - # If the function returns a result, we map it to the internal representation and - # returns the results as a tuple. - return tuple(map(verify_and_convert_result, res)) + + if not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={"type": type(res).__name__}, + ) + + # If the function returns a result, we map it to the internal representation and + # returns the results as a tuple. + return tuple(map(verify_and_convert_result, res)) return evaluate @@ -699,13 +708,7 @@ def mapper(_, it): yield eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: - try: - yield terminate() - except BaseException as e: - raise PySparkRuntimeError( - error_class="UDTF_EXEC_ERROR", - message_parameters={"method_name": "terminate", "error": str(e)}, - ) + yield terminate() return mapper, None, ser, ser From 26ed4fbc00dd9331807f747dd4e8ed7993c2497f Mon Sep 17 00:00:00 2001 From: itholic Date: Fri, 4 Aug 2023 10:35:06 +0900 Subject: [PATCH 16/68] [SPARK-43873][PS] Enabling `FrameDescribeTests` ### What changes were proposed in this pull request? This PR proposes to enable the test `FrameDescribeTests`. ### Why are the changes needed? To increate test coverage for pandas API on Spark with pandas 2.0.0 and above. ### Does this PR introduce _any_ user-facing change? No, it's test-only. ### How was this patch tested? Enabling the existing test. Closes #42319 from itholic/pandas_describe. 
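For context on the test changes below: under pandas 2.x, `describe()` no longer accepts the `datetime_is_numeric` keyword and always summarizes datetime columns numerically, which is why the updated test drops the argument and compares against a plain `describe()` call. A minimal sketch of that behavior (assuming pandas >= 2.0; illustrative only, not part of the patch):

```
# Assumes pandas >= 2.0, where `datetime_is_numeric` has been removed and
# datetime columns are treated as numeric by default.
import pandas as pd

pdf = pd.DataFrame({"ts": pd.to_datetime(["2021-01-01", "2021-03-01"])})
# Equivalent to the old pdf.describe(datetime_is_numeric=True) on pandas 1.x.
print(pdf.describe().loc[["count", "mean", "min", "max"]])
```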
Authored-by: itholic Signed-off-by: Hyukjin Kwon --- .../pandas/tests/computation/test_describe.py | 39 +++++-------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/python/pyspark/pandas/tests/computation/test_describe.py b/python/pyspark/pandas/tests/computation/test_describe.py index af98d2869da9b..bbee9654eae4b 100644 --- a/python/pyspark/pandas/tests/computation/test_describe.py +++ b/python/pyspark/pandas/tests/computation/test_describe.py @@ -39,10 +39,6 @@ def df_pair(self): psdf = ps.from_pandas(pdf) return pdf, psdf - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43556): Enable DataFrameSlowTests.test_describe for pandas 2.0.0.", - ) def test_describe(self): pdf, psdf = self.df_pair @@ -78,19 +74,10 @@ def test_describe(self): } ) pdf = psdf._to_pandas() - # NOTE: Set `datetime_is_numeric=True` for pandas: - # FutureWarning: Treating datetime data as categorical rather than numeric in - # `.describe` is deprecated and will be removed in a future version of pandas. - # Specify `datetime_is_numeric=True` to silence this - # warning and adopt the future behavior now. - # NOTE: Compare the result except percentiles, since we use approximate percentile - # so the result is different from pandas. if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], - pdf.describe(datetime_is_numeric=True) - .astype(str) - .loc[["count", "mean", "min", "max"]], + pdf.describe().astype(str).loc[["count", "mean", "min", "max"]], ) else: self.assert_eq( @@ -136,17 +123,13 @@ def test_describe(self): if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], - pdf.describe(datetime_is_numeric=True) - .astype(str) - .loc[["count", "mean", "min", "max"]], + pdf.describe().astype(str).loc[["count", "mean", "min", "max"]], ) psdf.A += psdf.A pdf.A += pdf.A self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], - pdf.describe(datetime_is_numeric=True) - .astype(str) - .loc[["count", "mean", "min", "max"]], + pdf.describe().astype(str).loc[["count", "mean", "min", "max"]], ) else: expected_result = ps.DataFrame( @@ -187,7 +170,7 @@ def test_describe(self): ) pdf = psdf._to_pandas() if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pandas_result = pdf.describe(datetime_is_numeric=True) + pandas_result = pdf.describe() pandas_result.B = pandas_result.B.astype(str) self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], @@ -195,7 +178,7 @@ def test_describe(self): ) psdf.A += psdf.A pdf.A += pdf.A - pandas_result = pdf.describe(datetime_is_numeric=True) + pandas_result = pdf.describe() pandas_result.B = pandas_result.B.astype(str) self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], @@ -252,7 +235,7 @@ def test_describe(self): ) pdf = psdf._to_pandas() if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pandas_result = pdf.describe(datetime_is_numeric=True) + pandas_result = pdf.describe() pandas_result.b = pandas_result.b.astype(str) self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], @@ -288,10 +271,6 @@ def test_describe(self): with self.assertRaisesRegex(ValueError, msg): psdf.describe() - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43556): Enable DataFrameSlowTests.test_describe for pandas 2.0.0.", - ) def test_describe_empty(self): # Empty DataFrame psdf = ps.DataFrame(columns=["A", 
"B"]) @@ -328,7 +307,7 @@ def test_describe_empty(self): # For timestamp type, we should convert NaT to None in pandas result # since pandas API on Spark doesn't support the NaT for object type. if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pdf_result = pdf[pdf.a != pdf.a].describe(datetime_is_numeric=True) + pdf_result = pdf[pdf.a != pdf.a].describe() self.assert_eq( psdf[psdf.a != psdf.a].describe(), pdf_result.where(pdf_result.notnull(), None).astype(str), @@ -367,7 +346,7 @@ def test_describe_empty(self): ) pdf = psdf._to_pandas() if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pdf_result = pdf[pdf.a != pdf.a].describe(datetime_is_numeric=True) + pdf_result = pdf[pdf.a != pdf.a].describe() pdf_result.b = pdf_result.b.where(pdf_result.b.notnull(), None).astype(str) self.assert_eq( psdf[psdf.a != psdf.a].describe(), @@ -417,7 +396,7 @@ def test_describe_empty(self): ) pdf = psdf._to_pandas() if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pdf_result = pdf[pdf.a != pdf.a].describe(datetime_is_numeric=True) + pdf_result = pdf[pdf.a != pdf.a].describe() self.assert_eq( psdf[psdf.a != psdf.a].describe(), pdf_result.where(pdf_result.notnull(), None).astype(str), From 678f47264e084af766ed339df21513f44d05897f Mon Sep 17 00:00:00 2001 From: itholic Date: Fri, 4 Aug 2023 10:36:04 +0900 Subject: [PATCH 17/68] [SPARK-43562][SPARK-43870][PS] Remove APIs from `DataFrame` and `Series` ### What changes were proposed in this pull request? This PR proposes to remove DataFrame/Series APIs that removed from [pandas 2](https://pandas.pydata.org/docs/dev/whatsnew/v2.0.0.html) and above. ### Why are the changes needed? To match the behavior to pandas. ### Does this PR introduce _any_ user-facing change? (DataFrame|Series).(iteritems|mad|append) will be removed. ### How was this patch tested? Enabling the existing tests. Closes #42268 from itholic/pandas_remove_df_api. Authored-by: itholic Signed-off-by: Hyukjin Kwon --- .../migration_guide/pyspark_upgrade.rst | 11 + .../source/reference/pyspark.pandas/frame.rst | 3 - .../reference/pyspark.pandas/groupby.rst | 1 - .../reference/pyspark.pandas/series.rst | 3 - python/pyspark/pandas/frame.py | 204 +----------------- python/pyspark/pandas/groupby.py | 81 ------- python/pyspark/pandas/namespace.py | 1 - python/pyspark/pandas/series.py | 112 +--------- .../pandas/tests/computation/test_combine.py | 71 ++---- .../pandas/tests/computation/test_compute.py | 34 --- .../pyspark/pandas/tests/groupby/test_stat.py | 7 - .../pandas/tests/indexes/test_indexing.py | 8 +- .../pandas/tests/series/test_compute.py | 18 +- .../pandas/tests/series/test_series.py | 8 +- .../pyspark/pandas/tests/series/test_stat.py | 35 --- 15 files changed, 41 insertions(+), 556 deletions(-) diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index 7513d64ef6c59..9bd879fb1a1a6 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -19,6 +19,17 @@ Upgrading PySpark ================== +Upgrading from PySpark 3.5 to 4.0 +--------------------------------- + +* In Spark 4.0, ``DataFrame.iteritems`` has been removed from pandas API on Spark, use ``DataFrame.items`` instead. +* In Spark 4.0, ``Series.iteritems`` has been removed from pandas API on Spark, use ``Series.items`` instead. +* In Spark 4.0, ``DataFrame.append`` has been removed from pandas API on Spark, use ``ps.concat`` instead. 
+* In Spark 4.0, ``Series.append`` has been removed from pandas API on Spark, use ``ps.concat`` instead. +* In Spark 4.0, ``DataFrame.mad`` has been removed from pandas API on Spark. +* In Spark 4.0, ``Series.mad`` has been removed from pandas API on Spark. + + Upgrading from PySpark 3.3 to 3.4 --------------------------------- diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst index a8d114187b94b..5f839a803d78a 100644 --- a/python/docs/source/reference/pyspark.pandas/frame.rst +++ b/python/docs/source/reference/pyspark.pandas/frame.rst @@ -79,7 +79,6 @@ Indexing, iteration DataFrame.iloc DataFrame.insert DataFrame.items - DataFrame.iteritems DataFrame.iterrows DataFrame.itertuples DataFrame.keys @@ -155,7 +154,6 @@ Computations / Descriptive Stats DataFrame.ewm DataFrame.kurt DataFrame.kurtosis - DataFrame.mad DataFrame.max DataFrame.mean DataFrame.min @@ -252,7 +250,6 @@ Combining / joining / merging .. autosummary:: :toctree: api/ - DataFrame.append DataFrame.assign DataFrame.merge DataFrame.join diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst index da1579fd72350..e71e81c56dd3e 100644 --- a/python/docs/source/reference/pyspark.pandas/groupby.rst +++ b/python/docs/source/reference/pyspark.pandas/groupby.rst @@ -68,7 +68,6 @@ Computations / Descriptive Stats GroupBy.filter GroupBy.first GroupBy.last - GroupBy.mad GroupBy.max GroupBy.mean GroupBy.median diff --git a/python/docs/source/reference/pyspark.pandas/series.rst b/python/docs/source/reference/pyspark.pandas/series.rst index a0119593f96ae..552acec096f69 100644 --- a/python/docs/source/reference/pyspark.pandas/series.rst +++ b/python/docs/source/reference/pyspark.pandas/series.rst @@ -70,7 +70,6 @@ Indexing, iteration Series.keys Series.pop Series.items - Series.iteritems Series.item Series.xs Series.get @@ -148,7 +147,6 @@ Computations / Descriptive Stats Series.ewm Series.filter Series.kurt - Series.mad Series.max Series.mean Series.min @@ -247,7 +245,6 @@ Combining / joining / merging .. autosummary:: :toctree: api/ - Series.append Series.compare Series.replace Series.update diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index d8a3f812c33ab..b960b3444e319 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -1880,11 +1880,9 @@ def items(self) -> Iterator[Tuple[Name, "Series"]]: polar bear 22000 koala marsupial 80000 - >>> for label, content in df.iteritems(): + >>> for label, content in df.items(): ... print('label:', label) ... print('content:', content.to_string()) - ... - ... # doctest: +SKIP label: species content: panda bear polar bear @@ -2057,20 +2055,6 @@ def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]: ): yield tuple(([k] if index else []) + list(v)) - def iteritems(self) -> Iterator[Tuple[Name, "Series"]]: - """ - This is an alias of ``items``. - - .. deprecated:: 3.4.0 - iteritems is deprecated and will be removed in a future version. - Use .items instead. - """ - warnings.warn( - "Deprecated in 3.4.0, and will be removed in 4.0.0. Use DataFrame.items instead.", - FutureWarning, - ) - return self.items() - def to_clipboard(self, excel: bool = True, sep: Optional[str] = None, **kwargs: Any) -> None: """ Copy object to the system clipboard. 
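The next `frame.py` hunk removes `DataFrame.append`, pointing callers at `ps.concat` as the upgrade-guide entries above state. A small migration sketch (illustrative only, not taken from the patch; assumes a running Spark session with pyspark.pandas available):

```
import pyspark.pandas as ps

df = ps.DataFrame([[1, 2], [3, 4]], columns=list("AB"))
# Removed in 4.0: df.append(df, ignore_index=True)
# Replacement recommended by the upgrade guide:
print(ps.concat([df, df], ignore_index=True))
```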
@@ -8837,91 +8821,6 @@ def combine_first(self, other: "DataFrame") -> "DataFrame": ) return DataFrame(internal) - def append( - self, - other: "DataFrame", - ignore_index: bool = False, - verify_integrity: bool = False, - sort: bool = False, - ) -> "DataFrame": - """ - Append rows of other to the end of caller, returning a new object. - - Columns in other that are not in the caller are added as new columns. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - other : DataFrame or Series/dict-like object, or list of these - The data to append. - - ignore_index : boolean, default False - If True, do not use the index labels. - - verify_integrity : boolean, default False - If True, raise ValueError on creating index with duplicates. - - sort : boolean, default False - Currently not supported. - - Returns - ------- - appended : DataFrame - - Examples - -------- - >>> df = ps.DataFrame([[1, 2], [3, 4]], columns=list('AB')) - - >>> df.append(df) - A B - 0 1 2 - 1 3 4 - 0 1 2 - 1 3 4 - - >>> df.append(df, ignore_index=True) - A B - 0 1 2 - 1 3 4 - 2 1 2 - 3 3 4 - """ - warnings.warn( - "The DataFrame.append method is deprecated " - "and will be removed in 4.0.0. " - "Use pyspark.pandas.concat instead.", - FutureWarning, - ) - if isinstance(other, ps.Series): - raise TypeError("DataFrames.append() does not support appending Series to DataFrames") - if sort: - raise NotImplementedError("The 'sort' parameter is currently not supported") - - if not ignore_index: - index_scols = self._internal.index_spark_columns - if len(index_scols) != other._internal.index_level: - raise ValueError("Both DataFrames have to have the same number of index levels") - - if ( - verify_integrity - and len(index_scols) > 0 - and ( - self._internal.spark_frame.select(index_scols) - .intersect( - other._internal.spark_frame.select(other._internal.index_spark_columns) - ) - .count() - ) - > 0 - ): - raise ValueError("Indices have overlapping values") - - # Lazy import to avoid circular dependency issues - from pyspark.pandas.namespace import concat - - return cast(DataFrame, concat([self, other], ignore_index=ignore_index)) - # TODO: add 'filter_func' and 'errors' parameter def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True) -> None: """ @@ -12719,107 +12618,6 @@ def explode(self, column: Name, ignore_index: bool = False) -> "DataFrame": result_df: DataFrame = DataFrame(internal) return result_df.reset_index(drop=True) if ignore_index else result_df - def mad(self, axis: Axis = 0) -> "Series": - """ - Return the mean absolute deviation of values. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - axis : {index (0), columns (1)} - Axis for the function to be applied on. - - Examples - -------- - >>> df = ps.DataFrame({'a': [1, 2, 3, np.nan], 'b': [0.1, 0.2, 0.3, np.nan]}, - ... columns=['a', 'b']) - - >>> df.mad() - a 0.666667 - b 0.066667 - dtype: float64 - - >>> df.mad(axis=1) # doctest: +SKIP - 0 0.45 - 1 0.90 - 2 1.35 - 3 NaN - dtype: float64 - """ - warnings.warn( - "The 'mad' method is deprecated and will be removed in 4.0.0. 
" - "To compute the same result, you may do `(df - df.mean()).abs().mean()`.", - FutureWarning, - ) - from pyspark.pandas.series import first_series - - axis = validate_axis(axis) - - if axis == 0: - - def get_spark_column(psdf: DataFrame, label: Label) -> PySparkColumn: - scol = psdf._internal.spark_column_for(label) - col_type = psdf._internal.spark_type_for(label) - - if isinstance(col_type, BooleanType): - scol = scol.cast("integer") - - return scol - - new_column_labels: List[Label] = [] - for label in self._internal.column_labels: - # Filtering out only columns of numeric and boolean type column. - dtype = self._psser_for(label).spark.data_type - if isinstance(dtype, (NumericType, BooleanType)): - new_column_labels.append(label) - - new_columns = [ - F.avg(get_spark_column(self, label)).alias(name_like_string(label)) - for label in new_column_labels - ] - - mean_data = self._internal.spark_frame.select(*new_columns).first() - - new_columns = [ - F.avg( - F.abs(get_spark_column(self, label) - mean_data[name_like_string(label)]) - ).alias(name_like_string(label)) - for label in new_column_labels - ] - - sdf = self._internal.spark_frame.select( - *[F.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)], *new_columns - ) - - # The data is expected to be small so it's fine to transpose/use the default index. - with ps.option_context("compute.max_rows", 1): - internal = InternalFrame( - spark_frame=sdf, - index_spark_columns=[scol_for(sdf, SPARK_DEFAULT_INDEX_NAME)], - column_labels=new_column_labels, - column_label_names=self._internal.column_label_names, - ) - return first_series(DataFrame(internal).transpose()) - - else: - - @pandas_udf(returnType=DoubleType()) # type: ignore[call-overload] - def calculate_columns_axis(*cols: pd.Series) -> pd.Series: - return pd.concat(cols, axis=1).mad(axis=1) - - internal = self._internal.copy( - column_labels=[None], - data_spark_columns=[ - calculate_columns_axis(*self._internal.data_spark_columns).alias( - SPARK_DEFAULT_SERIES_NAME - ) - ], - data_fields=[None], - column_label_names=None, - ) - return first_series(DataFrame(internal)) - def mode(self, axis: Axis = 0, numeric_only: bool = False, dropna: bool = True) -> "DataFrame": """ Get the mode(s) of each element along the selected axis. diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 663a635668ebf..2de328177937f 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -991,87 +991,6 @@ def skew(self) -> FrameLike: bool_to_numeric=True, ) - # TODO: 'axis', 'skipna', 'level' parameter should be implemented. - def mad(self) -> FrameLike: - """ - Compute mean absolute deviation of groups, excluding missing values. - - .. versionadded:: 3.4.0 - - .. deprecated:: 3.4.0 - - Examples - -------- - >>> df = ps.DataFrame({"A": [1, 2, 1, 1], "B": [True, False, False, True], - ... "C": [3, 4, 3, 4], "D": ["a", "b", "b", "a"]}) - - >>> df.groupby("A").mad() - B C - A - 1 0.444444 0.444444 - 2 0.000000 0.000000 - - >>> df.B.groupby(df.A).mad() - A - 1 0.444444 - 2 0.000000 - Name: B, dtype: float64 - - See Also - -------- - pyspark.pandas.Series.groupby - pyspark.pandas.DataFrame.groupby - """ - warnings.warn( - "The 'mad' method is deprecated and will be removed in a future version. 
" - "To compute the same result, you may do `(group_df - group_df.mean()).abs().mean()`.", - FutureWarning, - ) - groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))] - internal, agg_columns, sdf = self._prepare_reduce( - groupkey_names=groupkey_names, - accepted_spark_types=(NumericType, BooleanType), - bool_to_numeric=False, - ) - psdf: DataFrame = DataFrame(internal) - - if len(psdf._internal.column_labels) > 0: - window = Window.partitionBy(groupkey_names).rowsBetween( - Window.unboundedPreceding, Window.unboundedFollowing - ) - new_agg_scols = {} - new_stat_scols = [] - for agg_column in agg_columns: - # it is not able to directly use 'self._reduce_for_stat_function', due to - # 'it is not allowed to use a window function inside an aggregate function'. - # so we need to create temporary columns to compute the 'abs(x - avg(x))' here. - agg_column_name = agg_column._internal.data_spark_column_names[0] - new_agg_column_name = verify_temp_column_name( - psdf._internal.spark_frame, "__tmp_agg_col_{}__".format(agg_column_name) - ) - casted_agg_scol = F.col(agg_column_name).cast("double") - new_agg_scols[new_agg_column_name] = F.abs( - casted_agg_scol - F.avg(casted_agg_scol).over(window) - ) - new_stat_scols.append(F.avg(F.col(new_agg_column_name)).alias(agg_column_name)) - - sdf = ( - psdf._internal.spark_frame.withColumns(new_agg_scols) - .groupby(groupkey_names) - .agg(*new_stat_scols) - ) - else: - sdf = sdf.select(*groupkey_names).distinct() - - internal = internal.copy( - spark_frame=sdf, - index_spark_columns=[scol_for(sdf, col) for col in groupkey_names], - data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names], - data_fields=None, - ) - - return self._prepare_return(DataFrame(internal)) - def sem(self, ddof: int = 1) -> FrameLike: """ Compute standard error of the mean of groups, excluding missing values. diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index 3563a6d81b4fa..5ffec6bedb988 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -2365,7 +2365,6 @@ def concat( See Also -------- - Series.append : Concatenate Series. DataFrame.join : Join DataFrames using indexes. DataFrame.merge : Merge DataFrames by indexes or columns. diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 95ca92e78787d..9fbbadd5420a8 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -3584,71 +3584,6 @@ def nlargest(self, n: int = 5) -> "Series": """ return self.sort_values(ascending=False).head(n) - def append( - self, to_append: "Series", ignore_index: bool = False, verify_integrity: bool = False - ) -> "Series": - """ - Concatenate two or more Series. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - to_append : Series or list/tuple of Series - ignore_index : boolean, default False - If True, do not use the index labels. 
- verify_integrity : boolean, default False - If True, raise Exception on creating index with duplicates - - Returns - ------- - appended : Series - - Examples - -------- - >>> s1 = ps.Series([1, 2, 3]) - >>> s2 = ps.Series([4, 5, 6]) - >>> s3 = ps.Series([4, 5, 6], index=[3,4,5]) - - >>> s1.append(s2) # doctest: +SKIP - 0 1 - 1 2 - 2 3 - 0 4 - 1 5 - 2 6 - dtype: int64 - - >>> s1.append(s3) # doctest: +SKIP - 0 1 - 1 2 - 2 3 - 3 4 - 4 5 - 5 6 - dtype: int64 - - With ignore_index set to True: - - >>> s1.append(s2, ignore_index=True) # doctest: +SKIP - 0 1 - 1 2 - 2 3 - 3 4 - 4 5 - 5 6 - dtype: int64 - """ - warnings.warn( - "The Series.append method is deprecated " - "and will be removed in 4.0.0. " - "Use pyspark.pandas.concat instead.", - FutureWarning, - ) - return first_series( - self.to_frame().append(to_append.to_frame(), ignore_index, verify_integrity) - ).rename(self.name) - def sample( self, n: Optional[int] = None, @@ -5939,37 +5874,6 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: pdf.columns = pd.Index(where) return first_series(DataFrame(pdf.transpose())).rename(self.name) - def mad(self) -> float: - """ - Return the mean absolute deviation of values. - - .. deprecated:: 3.4.0 - - Examples - -------- - >>> s = ps.Series([1, 2, 3, 4]) - >>> s - 0 1 - 1 2 - 2 3 - 3 4 - dtype: int64 - - >>> s.mad() - 1.0 - """ - warnings.warn( - "The 'mad' method is deprecated and will be removed in 4.0.0. " - "To compute the same result, you may do `(series - series.mean()).abs().mean()`.", - FutureWarning, - ) - sdf = self._internal.spark_frame - spark_column = self.spark.column - avg = unpack_scalar(sdf.select(F.avg(spark_column))) - mad = unpack_scalar(sdf.select(F.avg(F.abs(spark_column - avg)))) - - return mad - def unstack(self, level: int = -1) -> DataFrame: """ Unstack, a.k.a. pivot, Series with MultiIndex to produce DataFrame. @@ -6083,7 +5987,7 @@ def items(self) -> Iterable[Tuple[Name, Any]]: This method returns an iterable tuple (index, value). This is convenient if you want to create a lazy iterator. - .. note:: Unlike pandas', the iteritems in pandas-on-Spark returns generator rather + .. note:: Unlike pandas', the itmes in pandas-on-Spark returns generator rather zip object Returns @@ -6123,20 +6027,6 @@ def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]: ): yield k, v - def iteritems(self) -> Iterable[Tuple[Name, Any]]: - """ - This is an alias of ``items``. - - .. deprecated:: 3.4.0 - iteritems is deprecated and will be removed in a future version. - Use .items instead. - """ - warnings.warn( - "Deprecated in 3.4, and will be removed in 4.0.0. Use Series.items instead.", - FutureWarning, - ) - return self.items() - def droplevel(self, level: Union[int, Name, List[Union[int, Name]]]) -> "Series": """ Return Series with requested index level(s) removed. 
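The removed `Series.mad` pointed callers at an equivalent expression in its old deprecation message; a short sketch of that replacement (illustrative only; assumes a running Spark session with pyspark.pandas available):

```
import pyspark.pandas as ps

s = ps.Series([1, 2, 3, 4])
# Equivalent of the removed s.mad(); yields 1.0, matching the old docstring example.
mad = (s - s.mean()).abs().mean()
print(mad)
```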
diff --git a/python/pyspark/pandas/tests/computation/test_combine.py b/python/pyspark/pandas/tests/computation/test_combine.py index dd55c0fd68661..adba20b5d99b3 100644 --- a/python/pyspark/pandas/tests/computation/test_combine.py +++ b/python/pyspark/pandas/tests/computation/test_combine.py @@ -41,46 +41,26 @@ def df_pair(self): psdf = ps.from_pandas(pdf) return pdf, psdf - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43562): Enable DataFrameTests.test_append for pandas 2.0.0.", - ) - def test_append(self): + def test_concat(self): pdf = pd.DataFrame([[1, 2], [3, 4]], columns=list("AB")) psdf = ps.from_pandas(pdf) other_pdf = pd.DataFrame([[3, 4], [5, 6]], columns=list("BC"), index=[2, 3]) other_psdf = ps.from_pandas(other_pdf) - self.assert_eq(psdf.append(psdf), pdf.append(pdf)) - self.assert_eq(psdf.append(psdf, ignore_index=True), pdf.append(pdf, ignore_index=True)) + self.assert_eq(ps.concat([psdf, psdf]), pd.concat([pdf, pdf])) + self.assert_eq( + ps.concat([psdf, psdf], ignore_index=True), pd.concat([pdf, pdf], ignore_index=True) + ) # Assert DataFrames with non-matching columns - self.assert_eq(psdf.append(other_psdf), pdf.append(other_pdf)) - - # Assert appending a Series fails - msg = "DataFrames.append() does not support appending Series to DataFrames" - with self.assertRaises(TypeError, msg=msg): - psdf.append(psdf["A"]) - - # Assert using the sort parameter raises an exception - msg = "The 'sort' parameter is currently not supported" - with self.assertRaises(NotImplementedError, msg=msg): - psdf.append(psdf, sort=True) + self.assert_eq(ps.concat([psdf, other_psdf]), pd.concat([pdf, other_pdf])) - # Assert using 'verify_integrity' only raises an exception for overlapping indices - self.assert_eq( - psdf.append(other_psdf, verify_integrity=True), - pdf.append(other_pdf, verify_integrity=True), - ) - msg = "Indices have overlapping values" - with self.assertRaises(ValueError, msg=msg): - psdf.append(psdf, verify_integrity=True) + ps.concat([psdf, psdf["A"]]) + # Assert appending a Series + self.assert_eq(ps.concat([psdf, psdf["A"]]), pd.concat([pdf, pdf["A"]])) - # Skip integrity verification when ignore_index=True - self.assert_eq( - psdf.append(psdf, ignore_index=True, verify_integrity=True), - pdf.append(pdf, ignore_index=True, verify_integrity=True), - ) + # Assert using the sort parameter + self.assert_eq(ps.concat([psdf, psdf], sort=True), pd.concat([pdf, pdf], sort=True)) # Assert appending multi-index DataFrames multi_index_pdf = pd.DataFrame([[1, 2], [3, 4]], columns=list("AB"), index=[[2, 3], [4, 5]]) @@ -91,45 +71,32 @@ def test_append(self): other_multi_index_psdf = ps.from_pandas(other_multi_index_pdf) self.assert_eq( - multi_index_psdf.append(multi_index_psdf), multi_index_pdf.append(multi_index_pdf) + ps.concat([multi_index_psdf, multi_index_psdf]), + pd.concat([multi_index_pdf, multi_index_pdf]), ) # Assert DataFrames with non-matching columns self.assert_eq( - multi_index_psdf.append(other_multi_index_psdf), - multi_index_pdf.append(other_multi_index_pdf), - ) - - # Assert using 'verify_integrity' only raises an exception for overlapping indices - self.assert_eq( - multi_index_psdf.append(other_multi_index_psdf, verify_integrity=True), - multi_index_pdf.append(other_multi_index_pdf, verify_integrity=True), - ) - with self.assertRaises(ValueError, msg=msg): - multi_index_psdf.append(multi_index_psdf, verify_integrity=True) - - # Skip integrity verification when ignore_index=True - self.assert_eq( - 
multi_index_psdf.append(multi_index_psdf, ignore_index=True, verify_integrity=True), - multi_index_pdf.append(multi_index_pdf, ignore_index=True, verify_integrity=True), + ps.concat([multi_index_psdf, other_multi_index_psdf]), + pd.concat([multi_index_pdf, other_multi_index_pdf]), ) # Assert trying to append DataFrames with different index levels msg = "Both DataFrames have to have the same number of index levels" with self.assertRaises(ValueError, msg=msg): - psdf.append(multi_index_psdf) + ps.concat([psdf, multi_index_psdf]) # Skip index level check when ignore_index=True self.assert_eq( - psdf.append(multi_index_psdf, ignore_index=True), - pdf.append(multi_index_pdf, ignore_index=True), + ps.concat([psdf, other_multi_index_psdf], ignore_index=True), + pd.concat([pdf, other_multi_index_pdf], ignore_index=True), ) columns = pd.MultiIndex.from_tuples([("A", "X"), ("A", "Y")]) pdf.columns = columns psdf.columns = columns - self.assert_eq(psdf.append(psdf), pdf.append(pdf)) + self.assert_eq(ps.concat([psdf, psdf]), pd.concat([pdf, pdf])) def test_merge(self): left_pdf = pd.DataFrame( diff --git a/python/pyspark/pandas/tests/computation/test_compute.py b/python/pyspark/pandas/tests/computation/test_compute.py index 5ce273c1f4769..d4b49f2ac8b01 100644 --- a/python/pyspark/pandas/tests/computation/test_compute.py +++ b/python/pyspark/pandas/tests/computation/test_compute.py @@ -78,40 +78,6 @@ def test_clip(self): str_psdf = ps.DataFrame({"A": ["a", "b", "c"]}, index=np.random.rand(3)) self.assert_eq(str_psdf.clip(1, 3), str_psdf) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43560): Enable DataFrameSlowTests.test_mad for pandas 2.0.0.", - ) - def test_mad(self): - pdf = pd.DataFrame( - { - "A": [1, 2, None, 4, np.nan], - "B": [-0.1, 0.2, -0.3, np.nan, 0.5], - "C": ["a", "b", "c", "d", "e"], - } - ) - psdf = ps.from_pandas(pdf) - - self.assert_eq(psdf.mad(), pdf.mad()) - self.assert_eq(psdf.mad(axis=1), pdf.mad(axis=1)) - - with self.assertRaises(ValueError): - psdf.mad(axis=2) - - # MultiIndex columns - columns = pd.MultiIndex.from_tuples([("A", "X"), ("A", "Y"), ("A", "Z")]) - pdf.columns = columns - psdf.columns = columns - - self.assert_eq(psdf.mad(), pdf.mad()) - self.assert_eq(psdf.mad(axis=1), pdf.mad(axis=1)) - - pdf = pd.DataFrame({"A": [True, True, False, False], "B": [True, False, False, True]}) - psdf = ps.from_pandas(pdf) - - self.assert_eq(psdf.mad(), pdf.mad()) - self.assert_eq(psdf.mad(axis=1), pdf.mad(axis=1)) - def test_mode(self): pdf = pd.DataFrame( { diff --git a/python/pyspark/pandas/tests/groupby/test_stat.py b/python/pyspark/pandas/tests/groupby/test_stat.py index bfdeeecce303c..8a5096942e689 100644 --- a/python/pyspark/pandas/tests/groupby/test_stat.py +++ b/python/pyspark/pandas/tests/groupby/test_stat.py @@ -206,13 +206,6 @@ def test_sum(self): psdf.groupby("A").sum(min_count=3).sort_index(), ) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43553): Enable GroupByTests.test_mad for pandas 2.0.0.", - ) - def test_mad(self): - self._test_stat_func(lambda groupby_obj: groupby_obj.mad()) - def test_first(self): self._test_stat_func(lambda groupby_obj: groupby_obj.first()) self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=None)) diff --git a/python/pyspark/pandas/tests/indexes/test_indexing.py b/python/pyspark/pandas/tests/indexes/test_indexing.py index 64fc75347baf3..111dd09696d79 100644 --- a/python/pyspark/pandas/tests/indexes/test_indexing.py +++ 
b/python/pyspark/pandas/tests/indexes/test_indexing.py @@ -53,11 +53,7 @@ def test_head(self): with option_context("compute.ordered_head", True): self.assert_eq(psdf.head(), pdf.head()) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43559): Enable DataFrameSlowTests.test_iteritems for pandas 2.0.0.", - ) - def test_iteritems(self): + def test_items(self): pdf = pd.DataFrame( {"species": ["bear", "bear", "marsupial"], "population": [1864, 22000, 80000]}, index=["panda", "polar", "koala"], @@ -65,7 +61,7 @@ def test_iteritems(self): ) psdf = ps.from_pandas(pdf) - for (p_name, p_items), (k_name, k_items) in zip(pdf.iteritems(), psdf.iteritems()): + for (p_name, p_items), (k_name, k_items) in zip(pdf.items(), psdf.items()): self.assert_eq(p_name, k_name) self.assert_eq(p_items, k_items) diff --git a/python/pyspark/pandas/tests/series/test_compute.py b/python/pyspark/pandas/tests/series/test_compute.py index 2fbdaef865e50..7d39f0523d456 100644 --- a/python/pyspark/pandas/tests/series/test_compute.py +++ b/python/pyspark/pandas/tests/series/test_compute.py @@ -142,11 +142,7 @@ def test_compare(self): expected = ps.DataFrame([[1, 2], [2, 3]], index=["x", "y"], columns=["self", "other"]) self.assert_eq(expected, psser.compare(psser + 1).sort_index()) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43465): Enable SeriesTests.test_append for pandas 2.0.0.", - ) - def test_append(self): + def test_concat(self): pser1 = pd.Series([1, 2, 3], name="0") pser2 = pd.Series([4, 5, 6], name="0") pser3 = pd.Series([4, 5, 6], index=[3, 4, 5], name="0") @@ -154,17 +150,13 @@ def test_append(self): psser2 = ps.from_pandas(pser2) psser3 = ps.from_pandas(pser3) - self.assert_eq(psser1.append(psser2), pser1.append(pser2)) - self.assert_eq(psser1.append(psser3), pser1.append(pser3)) + self.assert_eq(ps.concat([psser1, psser2]), pd.concat([pser1, pser2])) + self.assert_eq(ps.concat([psser1, psser3]), pd.concat([pser1, pser3])) self.assert_eq( - psser1.append(psser2, ignore_index=True), pser1.append(pser2, ignore_index=True) + ps.concat([psser1, psser2], ignore_index=True), + pd.concat([pser1, pser2], ignore_index=True), ) - psser1.append(psser3, verify_integrity=True) - msg = "Indices have overlapping values" - with self.assertRaises(ValueError, msg=msg): - psser1.append(psser2, verify_integrity=True) - def test_shift(self): pser = pd.Series([10, 20, 15, 30, 45], name="x") psser = ps.Series(pser) diff --git a/python/pyspark/pandas/tests/series/test_series.py b/python/pyspark/pandas/tests/series/test_series.py index 116acb2a5b2b3..f7f186b672452 100644 --- a/python/pyspark/pandas/tests/series/test_series.py +++ b/python/pyspark/pandas/tests/series/test_series.py @@ -670,15 +670,11 @@ def test_filter(self): with self.assertRaisesRegex(ValueError, "The item should not be empty."): psser.filter(items=[(), ("three", "z")]) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43480): Enable SeriesTests.test_iteritems for pandas 2.0.0.", - ) - def test_iteritems(self): + def test_items(self): pser = pd.Series(["A", "B", "C"]) psser = ps.from_pandas(pser) - for (p_name, p_items), (k_name, k_items) in zip(pser.iteritems(), psser.iteritems()): + for (p_name, p_items), (k_name, k_items) in zip(pser.items(), psser.items()): self.assert_eq(p_name, k_name) self.assert_eq(p_items, k_items) diff --git a/python/pyspark/pandas/tests/series/test_stat.py b/python/pyspark/pandas/tests/series/test_stat.py index 
0d6e242492149..048a4c94fd939 100644 --- a/python/pyspark/pandas/tests/series/test_stat.py +++ b/python/pyspark/pandas/tests/series/test_stat.py @@ -524,41 +524,6 @@ def test_div_zero_and_nan(self): self.assert_eq(pser // 0, psser // 0) self.assert_eq(pser.floordiv(np.nan), psser.floordiv(np.nan)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43468): Enable SeriesTests.test_mad for pandas 2.0.0.", - ) - def test_mad(self): - pser = pd.Series([1, 2, 3, 4], name="Koalas") - psser = ps.from_pandas(pser) - - self.assert_eq(pser.mad(), psser.mad()) - - pser = pd.Series([None, -2, 5, 10, 50, np.nan, -20], name="Koalas") - psser = ps.from_pandas(pser) - - self.assert_eq(pser.mad(), psser.mad()) - - pmidx = pd.MultiIndex.from_tuples( - [("a", "1"), ("a", "2"), ("b", "1"), ("b", "2"), ("c", "1")] - ) - pser = pd.Series([1, 2, 3, 4, 5], name="Koalas") - pser.index = pmidx - psser = ps.from_pandas(pser) - - self.assert_eq(pser.mad(), psser.mad()) - - pmidx = pd.MultiIndex.from_tuples( - [("a", "1"), ("a", "2"), ("b", "1"), ("b", "2"), ("c", "1")] - ) - pser = pd.Series([None, -2, 5, 50, np.nan], name="Koalas") - pser.index = pmidx - psser = ps.from_pandas(pser) - - # Mark almost as True to avoid precision issue like: - # "21.555555555555554 != 21.555555555555557" - self.assert_eq(pser.mad(), psser.mad(), almost=True) - @unittest.skipIf( LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), "TODO(SPARK-43481): Enable SeriesTests.test_product for pandas 2.0.0.", From 4ed59d932631ba819d6b6071ee91408622a312db Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 4 Aug 2023 09:51:31 +0800 Subject: [PATCH 18/68] [SPARK-44619][INFRA] Free up disk space for container jobs ### What changes were proposed in this pull request? Free up disk space for container jobs ### Why are the changes needed? increase the available disk space before this PR ![image](https://github.com/apache/spark/assets/7322292/64230324-607b-4c1d-ac2d-84b9bcaab12a) after this PR ![image](https://github.com/apache/spark/assets/7322292/aafed2d6-5d26-4f7f-b020-1efe4f551a8f) ### Does this PR introduce _any_ user-facing change? No, infra-only ### How was this patch tested? updated CI Closes #42253 from zhengruifeng/infra_clean_container. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .github/workflows/build_and_test.yml | 6 +++++ dev/free_disk_space_container | 33 ++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100755 dev/free_disk_space_container diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index d9bcdfcbfa474..ea0c8e1d7fdeb 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -410,6 +410,8 @@ jobs: key: pyspark-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | pyspark-coursier- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java ${{ matrix.java }} uses: actions/setup-java@v3 with: @@ -508,6 +510,8 @@ jobs: key: sparkr-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | sparkr-coursier- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java ${{ inputs.java }} uses: actions/setup-java@v3 with: @@ -616,6 +620,8 @@ jobs: key: docs-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | docs-maven- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java 8 uses: actions/setup-java@v3 with: diff --git a/dev/free_disk_space_container b/dev/free_disk_space_container new file mode 100755 index 0000000000000..cc3b74643e4fa --- /dev/null +++ b/dev/free_disk_space_container @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +echo "==================================" +echo "Free up disk space on CI system" +echo "==================================" + +echo "Listing 100 largest packages" +dpkg-query -Wf '${Installed-Size}\t${Package}\n' | sort -n | tail -n 100 +df -h + +echo "Removing large packages" +rm -rf /__t/CodeQL +rm -rf /__t/go +rm -rf /__t/node + +df -h From 492f6fac02a00b9ad545d84fa3f10a021a8e71b9 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 4 Aug 2023 12:03:54 +0900 Subject: [PATCH 19/68] [SPARK-44664][PYTHON][CONNECT] Release the execute when closing the iterator in Python client ### What changes were proposed in this pull request? This PR implements the symmetry of https://github.com/apache/spark/pull/42331 and https://github.com/apache/spark/pull/42304 1. It releases the execute when the error is raised during the iteration 2. When you explicitly close the generator, (e.g., either `generator.close()` or explicit `GeneratorExit`), it releases the execution. ### Why are the changes needed? For the feature parity, see also https://github.com/apache/spark/pull/42331 and https://github.com/apache/spark/pull/42304 ### Does this PR introduce _any_ user-facing change? 
See also https://github.com/apache/spark/pull/42331 and https://github.com/apache/spark/pull/42304 ### How was this patch tested? Tests will be added separately. Closes #42330 from HyukjinKwon/python-error-release. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/reattach.py | 110 ++++++++++++------ 1 file changed, 72 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 4d4cce0ca4413..702107d97f549 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -111,9 +111,9 @@ def send(self, value: Any) -> pb2.ExecutePlanResponse: self._last_returned_response_id = ret.response_id if ret.HasField("result_complete"): self._result_complete = True - self._release_execute(None) # release all + self._release_all() else: - self._release_execute(self._last_returned_response_id) + self._release_until(self._last_returned_response_id) self._current = None return ret @@ -125,61 +125,93 @@ def _has_next(self) -> bool: # After response complete response return False else: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): - with attempt: - # on first try, we use the existing iterator. - if not attempt.is_first_try(): - # on retry, the iterator is borked, so we need a new one - self._iterator = iter( - self._stub.ReattachExecute(self._create_reattach_execute_request()) - ) - - if self._current is None: - try: - self._current = next(self._iterator) - except StopIteration: - pass - - has_next = self._current is not None - - # Graceful reattach: - # If iterator ended, but there was no ResponseComplete, it means that - # there is more, and we need to reattach. While ResponseComplete didn't - # arrive, we keep reattaching. - if not self._result_complete and not has_next: - while not has_next: + try: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + # on first try, we use the existing iterator. + if not attempt.is_first_try(): + # on retry, the iterator is borked, so we need a new one self._iterator = iter( self._stub.ReattachExecute(self._create_reattach_execute_request()) ) - # shouldn't change - assert not self._result_complete + + if self._current is None: try: self._current = next(self._iterator) except StopIteration: pass - has_next = self._current is not None - return has_next + + has_next = self._current is not None + + # Graceful reattach: + # If iterator ended, but there was no ResponseComplete, it means that + # there is more, and we need to reattach. While ResponseComplete didn't + # arrive, we keep reattaching. + if not self._result_complete and not has_next: + while not has_next: + self._iterator = iter( + self._stub.ReattachExecute( + self._create_reattach_execute_request() + ) + ) + # shouldn't change + assert not self._result_complete + try: + self._current = next(self._iterator) + except StopIteration: + pass + has_next = self._current is not None + return has_next + except Exception as e: + self._release_all() + raise e return False - def _release_execute(self, until_response_id: Optional[str]) -> None: + def _release_until(self, until_response_id: str) -> None: """ - Inform the server to release the execution. + Inform the server to release the buffered execution results until and including given + result. 
This will send an asynchronous RPC which will not block this iterator, the iterator can continue to be consumed. + """ + if self._result_complete: + return + + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.client.core import Retrying - Release with untilResponseId informs the server that the iterator has been consumed until - and including response with that responseId, and these responses can be freed. + request = self._create_release_execute_request(until_response_id) - Release with None means that the responses have been completely consumed and informs the - server that the completed execution can be completely freed. + def target() -> None: + try: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + self._stub.ReleaseExecute(request) + except Exception as e: + warnings.warn(f"ReleaseExecute failed with exception: {e}.") + + ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + + def _release_all(self) -> None: + """ + Inform the server to release the execution, either because all results were consumed, + or the execution finished with error and the error was received. + + This will send an asynchronous RPC which will not block this. The client continues + executing, and if the release fails, server is equipped to deal with abandoned executions. """ + if self._result_complete: + return + from pyspark.sql.connect.client.core import SparkConnectClient from pyspark.sql.connect.client.core import Retrying - request = self._create_release_execute_request(until_response_id) + request = self._create_release_execute_request(None) def target() -> None: try: @@ -192,6 +224,7 @@ def target() -> None: warnings.warn(f"ReleaseExecute failed with exception: {e}.") ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + self._result_complete = True def _create_reattach_execute_request(self) -> pb2.ReattachExecuteRequest: reattach = pb2.ReattachExecuteRequest( @@ -231,6 +264,7 @@ def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> A super().throw(type, value, traceback) def close(self) -> None: + self._release_all() return super().close() def __del__(self) -> None: From 52437bc73695e392bee60fbb340b6de4324b25d8 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 4 Aug 2023 12:05:19 +0900 Subject: [PATCH 20/68] [SPARK-44624][CONNECT] Retry ExecutePlan in case initial request didn't reach server ### What changes were proposed in this pull request? If the ExecutePlan never reached the server, a ReattachExecute will fail with INVALID_HANDLE.OPERATION_NOT_FOUND. In that case, we could try to send ExecutePlan again. ### Why are the changes needed? This solves an edge case of reattachable execution where the initial execution never reached the server. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Testing these failures is difficult, will require some special testing setup Closes #42282 from juliuszsompolski/SPARK-44624-fix. 
Authored-by: Juliusz Sompolski Signed-off-by: Hyukjin Kwon --- ...cutePlanResponseReattachableIterator.scala | 43 +++++++++++++++---- .../sql/connect/client/GrpcRetryHandler.scala | 10 ++++- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index fc07deaa081f8..41648c3c10048 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -20,7 +20,8 @@ import java.util.UUID import scala.util.control.NonFatal -import io.grpc.ManagedChannel +import io.grpc.{ManagedChannel, StatusRuntimeException} +import io.grpc.protobuf.StatusProto import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto @@ -38,15 +39,12 @@ import org.apache.spark.internal.Logging * Initial iterator is the result of an ExecutePlan on the request, but it can be reattached with * ReattachExecute request. ReattachExecute request is provided the responseId of last returned * ExecutePlanResponse on the iterator to return a new iterator from server that continues after - * that. + * that. If the initial ExecutePlan did not even reach the server, and hence reattach fails with + * INVALID_HANDLE.OPERATION_NOT_FOUND, we attempt to retry ExecutePlan. * * In reattachable execute the server does buffer some responses in case the client needs to * backtrack. To let server release this buffer sooner, this iterator asynchronously sends * ReleaseExecute RPCs that instruct the server to release responses that it already processed. - * - * Note: If the initial ExecutePlan did not even reach the server and execution didn't start, the - * ReattachExecute can still fail with INVALID_HANDLE.OPERATION_NOT_FOUND, failing the whole - * operation. */ class ExecutePlanResponseReattachableIterator( request: proto.ExecutePlanRequest, @@ -113,7 +111,7 @@ class ExecutePlanResponseReattachableIterator( // on retry, the iterator is borked, so we need a new one iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) } - iterator.next() + callIter(_.next()) } // Record last returned response, to know where to restart in case of reattach. @@ -146,7 +144,7 @@ class ExecutePlanResponseReattachableIterator( // on retry, the iterator is borked, so we need a new one iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) } - var hasNext = iterator.hasNext() + var hasNext = callIter(_.hasNext()) // Graceful reattach: // If iterator ended, but there was no ResultComplete, it means that there is more, // and we need to reattach. @@ -154,7 +152,7 @@ class ExecutePlanResponseReattachableIterator( do { iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) assert(!resultComplete) // shouldn't change... - hasNext = iterator.hasNext() + hasNext = callIter(_.hasNext()) // It's possible that the new iterator will be empty, so we need to loop to get another. // Eventually, there will be a non empty iterator, because there is always a // ResultComplete inserted by the server at the end of the stream. @@ -197,6 +195,33 @@ class ExecutePlanResponseReattachableIterator( } } + /** + * Call next() or hasNext() on the iterator. 
If this fails with this operationId not existing on + * the server, this means that the initial ExecutePlan request didn't even reach the server. In + * that case, attempt to start again with ExecutePlan. + * + * Called inside retry block, so retryable failure will get handled upstream. + */ + private def callIter[V](iterFun: java.util.Iterator[proto.ExecutePlanResponse] => V) = { + try { + iterFun(iterator) + } catch { + case ex: StatusRuntimeException + if StatusProto + .fromThrowable(ex) + .getMessage + .contains("INVALID_HANDLE.OPERATION_NOT_FOUND") => + if (lastReturnedResponseId.isDefined) { + throw new IllegalStateException( + "OPERATION_NOT_FOUND on the server but responses were already received from it.", + ex) + } + // Try a new ExecutePlan, and throw upstream for retry. + iterator = rawBlockingStub.executePlan(initialRequest) + throw new GrpcRetryHandler.RetryException + } + } + /** * Create result callback to the asynchronouse ReleaseExecute. The client does not block on * ReleaseExecute and continues with iteration, but if it fails with a retryable error, the diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala index ef446399f1674..47ff975b26756 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala @@ -164,7 +164,9 @@ private[client] object GrpcRetryHandler extends Logging { try { return fn } catch { - case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < retryPolicy.maxRetries => + case NonFatal(e) + if (retryPolicy.canRetry(e) || e.isInstanceOf[RetryException]) + && currentRetryNum < retryPolicy.maxRetries => logWarning( s"Non fatal error during RPC execution: $e, " + s"retrying (currentRetryNum=$currentRetryNum)") @@ -209,4 +211,10 @@ private[client] object GrpcRetryHandler extends Logging { maxBackoff: FiniteDuration = FiniteDuration(1, "min"), backoffMultiplier: Double = 4.0, canRetry: Throwable => Boolean = retryException) {} + + /** + * An exception that can be thrown upstream when inside retry and which will be retryable + * regardless of policy. + */ + class RetryException extends Throwable } From ce1fe57cdd7004a891ef8b97c77ac96b3719efcd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 Aug 2023 11:26:36 +0800 Subject: [PATCH 21/68] [SPARK-44653][SQL] Non-trivial DataFrame unions should not break caching ### What changes were proposed in this pull request? We have a long-standing tricky optimization in `Dataset.union`, which invokes the optimizer rule `CombineUnions` to pre-optimize the analyzed plan. This is to avoid too large analyzed plan for a specific dataframe query pattern `df1.union(df2).union(df3).union...`. This tricky optimization is designed to break dataframe caching, but we thought it was fine as people usually won't cache the intermediate dataframe in a union chain. However, `CombineUnions` gets improved from time to time (e.g. https://github.com/apache/spark/pull/35214) and now it can optimize a wide range of Union patterns. Now it's possible that people union two dataframe, do something with `select`, and cache it. Then the dataframe is unioned again with other dataframes and people expect the df cache to work. However the cache won't work due to the tricky optimization in `Dataset.union`. 
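For illustration, the problematic pattern looks roughly like this (a minimal PySpark sketch; `df1`, `df2`, `df3` and the column `i` are invented, mirroring the shape of the regression test added below):

```python
union_df = df1.union(df2).select("i")
union_df.cache()                            # user caches the intermediate union

final_df = union_df.union(df3.select("i"))
# Because Dataset.union used to pre-optimize the analyzed plan with CombineUnions,
# the cached sub-plan got rewritten and final_df no longer used the cache.
```
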
This PR updates `Dataset.union` to only combine adjacent Unions to match the original purpose. ### Why are the changes needed? Fix perf regression due to breaking df caching ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new test Closes #42315 from cloud-fan/union. Lead-authored-by: Wenchen Fan Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 56 ++++++++++++++++--- .../apache/spark/sql/DatasetCacheSuite.scala | 21 +++++++ 3 files changed, 70 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9fc664bb1c26d..f83cd36f0a82b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -157,7 +157,7 @@ abstract class Optimizer(catalogManager: CatalogManager) // since the other rules might make two separate Unions operators adjacent. Batch("Inline CTE", Once, InlineCTE()) :: - Batch("Union", Once, + Batch("Union", fixedPoint, RemoveNoopOperators, CombineUnions, RemoveNoopUnion) :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7b2259a6d9945..61c83829d2012 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -42,11 +42,10 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} -import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId @@ -2241,6 +2240,51 @@ class Dataset[T] private[sql]( Offset(Literal(n), logicalPlan) } + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + private def combineUnions(plan: LogicalPlan): LogicalPlan = { + plan.transformDownWithPruning(_.containsPattern(TreePattern.UNION)) { + case Distinct(u: Union) => + Distinct(flattenUnion(u, isUnionDistinct = true)) + // Only handle distinct-like 'Deduplicate', where the keys == output + case Deduplicate(keys: Seq[Attribute], u: Union) if AttributeSet(keys) == u.outputSet => + Deduplicate(keys, flattenUnion(u, true)) + case u: Union => + flattenUnion(u, isUnionDistinct = false) + } + } + + private def flattenUnion(u: Union, isUnionDistinct: Boolean): Union = { + var changed = false + // We only need to look at the direct children of Union, as the nested adjacent Unions should + // have been combined already by previous `Dataset#union` transformations. 
+ val newChildren = u.children.flatMap { + case Distinct(Union(children, byName, allowMissingCol)) + if isUnionDistinct && byName == u.byName && allowMissingCol == u.allowMissingCol => + changed = true + children + // Only handle distinct-like 'Deduplicate', where the keys == output + case Deduplicate(keys: Seq[Attribute], child @ Union(children, byName, allowMissingCol)) + if AttributeSet(keys) == child.outputSet && isUnionDistinct && byName == u.byName && + allowMissingCol == u.allowMissingCol => + changed = true + children + case Union(children, byName, allowMissingCol) + if !isUnionDistinct && byName == u.byName && allowMissingCol == u.allowMissingCol => + changed = true + children + case other => + Seq(other) + } + if (changed) { + val newUnion = Union(newChildren) + newUnion.copyTagsFrom(u) + newUnion + } else { + u + } + } + /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * @@ -2272,9 +2316,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def union(other: Dataset[T]): Dataset[T] = withSetOperator { - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)) + combineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -2366,9 +2408,7 @@ class Dataset[T] private[sql]( * @since 3.1.0 */ def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = withSetOperator { - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. - CombineUnions(Union(logicalPlan :: other.logicalPlan :: Nil, true, allowMissingColumns)) + combineUnions(Union(logicalPlan :: other.logicalPlan :: Nil, true, allowMissingColumns)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 6033b9fee848e..a657c6212aa07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -273,4 +273,25 @@ class DatasetCacheSuite extends QueryTest } } } + + test("SPARK-44653: non-trivial DataFrame unions should not break caching") { + val df1 = Seq(1 -> 1).toDF("i", "j") + val df2 = Seq(2 -> 2).toDF("i", "j") + val df3 = Seq(3 -> 3).toDF("i", "j") + + withClue("positive") { + val unionDf = df1.union(df2).select($"i") + unionDf.cache() + val finalDf = unionDf.union(df3.select($"i")) + assert(finalDf.queryExecution.executedPlan.exists(_.isInstanceOf[InMemoryTableScanExec])) + } + + withClue("negative") { + val unionDf = df1.union(df2) + unionDf.cache() + val finalDf = unionDf.union(df3) + // It's by design to break caching here. + assert(!finalDf.queryExecution.executedPlan.exists(_.isInstanceOf[InMemoryTableScanExec])) + } + } } From 2c2d6534bebed3c7bfa0842b84aa27674b721410 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Fri, 4 Aug 2023 14:24:02 +0900 Subject: [PATCH 22/68] [SPARK-44582][SQL] Skip iterator on SMJ if it was cleaned up ### What changes were proposed in this pull request? Bugfix for SMJ which may cause JVM crash. **When will the JVM crash** ``` Query pattern: TableScan TableScan | | Exchange Exchange | | Sort 1 Sort 2 | | Window 1 Window 2 \ / \ / SMJ | | WriteFileCommand ``` 1. WriteFileCommand call hasNext() to check if the input is empty. 2. SMJ call findNextJoinRows() to find all matched rows. 
2.1 SMJ tries to get the first row in the left child. 2.1.1 Sort 1 will sort all the input rows in the Offheap memory. 2.1.2 Window 1 will read one group data and the first row in next group (named X), return the first row in the first group. 2.2 SMJ tries to get the first row in the right child. 2.2.1 Sort 2 and Window 2 are empty, do nothing. 2.3 Inner SMJ will finish, since there will definitely be no join rows, call earlyCleanupResources() to free offHeap memory. 3. WriteFileCommand call hasNext() again to write the input data to the files. 4. SMJ call findNextJoinRows() to find all matched rows. 4.1 SMJ tries to get the first row in the left child. 4.2 Window 1 tries to add row X into the group buffer, which will accesse unallocated memory, the JVM may or may not crash. In this PR, if SMJ has already been cleaned up, skip iterator on it. ### Why are the changes needed? Bugfix for SMJ. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Test in our production environment. For unsafe API, when read the unallocated memory, the program may get the old value, or get a unexpected value, or cause the JVM crash. I don't think the UIT will be stable. The JVM crash stack ``` Stack: [0x00007f8a03800000,0x00007f8a04000000], sp=0x00007f8a03ffd620, free space=8181k Native frames: (J=compiled Java code, j=interpreted, Vv=VM code, C=native code) v ~StubRoutines::jint_disjoint_arraycopy J 36127 C2 org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.add(Lorg/apache/spark/sql/catalyst/expressions/UnsafeRow;)V (188 bytes) 0x00007f966187ac9f [0x00007f966187a820+0x47f] J 36146 C2 org.apache.spark.sql.execution.window.WindowExec$$anon$1.next()Ljava/lang/Object; (5 bytes) 0x00007f9661a8eefc [0x00007f9661a8dd60+0x119c] J 36153 C2 org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage4.processNext()V (381 bytes) 0x00007f966180185c [0x00007f9661801760+0xfc] J 36246 C2 org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage7.smj_findNextJoinRows_0$(Lorg/apache/spark/sql/catalyst/expressions/GeneratedClass$GeneratedIteratorForCodegenStage7;Lscala/collection/Iterator;Lscala/collection/Iterator;)Z (392 bytes) 0x00007f96607388f0 [0x00007f96607381e0+0x710] J 36249 C1 org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage7.processNext()V (109 bytes) 0x00007f965fa8ee64 [0x00007f965fa8e560+0x904] J 35645 C2 org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$2.hasNext()Z (31 bytes) 0x00007f965fbc58e4 [0x00007f965fbc58a0+0x44] j org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$executeTask$1(Lscala/collection/Iterator;Lorg/apache/spark/sql/execution/datasources/FileFormatDataWriter;)Lorg/apache/spark/sql/execution/datasources/WriteTaskResult;+1 j org.apache.spark.sql.execution.datasources.FileFormatWriter$$$Lambda$4398.apply()Ljava/lang/Object;+8 j org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Lscala/Function0;Lscala/Function0;Lscala/Function0;)Ljava/lang/Object;+4 j org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(Lorg/apache/spark/sql/execution/datasources/WriteJobDescription;JIIILorg/apache/spark/internal/io/FileCommitProtocol;ILscala/collection/Iterator;)Lorg/apache/spark/sql/execution/datasources/WriteTaskResult;+258 J 30523 C1 
org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$write$23(Lorg/apache/spark/sql/execution/datasources/WriteJobDescription;JLorg/apache/spark/internal/io/FileCommitProtocol;Lscala/runtime/IntRef;Lscala/collection/immutable/Map;Lorg/apache/spark/TaskContext;Lscala/collection/Iterator;)Lorg/apache/spark/sql/execution/datasources/WriteTaskResult; (61 bytes) 0x00007f966066b004 [0x00007f966066a7a0+0x864] J 30529 C1 org.apache.spark.sql.execution.datasources.FileFormatWriter$$$Lambda$3569.apply(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; (32 bytes) 0x00007f965f79bd1c [0x00007f965f79baa0+0x27c] J 29322 C1 org.apache.spark.scheduler.ResultTask.runTask(Lorg/apache/spark/TaskContext;)Ljava/lang/Object; (210 bytes) 0x00007f966094bd0c [0x00007f96609497a0+0x256c] J 24071 C1 org.apache.spark.scheduler.Task.run(JILorg/apache/spark/metrics/MetricsSystem;Lscala/collection/immutable/Map;)Ljava/lang/Object; (536 bytes) 0x00007f965fca493c [0x00007f965fca1000+0x393c] J 23198 C1 org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Lorg/apache/spark/executor/Executor$TaskRunner;Lscala/runtime/BooleanRef;)Ljava/lang/Object; (43 bytes) 0x00007f965f86373c [0x00007f965f8634e0+0x25c] J 23196 C1 org.apache.spark.executor.Executor$TaskRunner$$Lambda$984.apply()Ljava/lang/Object; (12 bytes) 0x00007f965f860e44 [0x00007f965f860dc0+0x84] ``` Closes #42206 from wankunde/smj_cleanup. Authored-by: Kun Wan Signed-off-by: Hyukjin Kwon --- .../execution/joins/SortMergeJoinExec.scala | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 0241f683d6902..8d49b1558d687 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -556,14 +556,18 @@ case class SortMergeJoinExec( val doJoin = joinType match { case _: InnerLike => + val cleanedFlag = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "cleanedFlag", v => s"$v = false;") codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow, - eagerCleanup) + eagerCleanup, cleanedFlag) case LeftOuter | RightOuter => codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, ctx.freshName("hasOutputRow"), outputRow, eagerCleanup) case LeftSemi => + val cleanedFlag = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "cleanedFlag", v => s"$v = false;") codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, - ctx.freshName("hasOutputRow"), outputRow, eagerCleanup) + ctx.freshName("hasOutputRow"), outputRow, eagerCleanup, cleanedFlag) case LeftAnti => codegenAnti(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, loadStreamed, ctx.freshName("hasMatchedRow"), outputRow, eagerCleanup) @@ -606,8 +610,13 @@ case class SortMergeJoinExec( bufferedRow: String, conditionCheck: String, outputRow: String, - eagerCleanup: String): String = { + eagerCleanup: String, + cleanedFlag: String): String = { s""" + |if($cleanedFlag) { + | return; + |} + | |while ($findNextJoinRows) { | $beforeLoop | while ($matchIterator.hasNext()) { @@ -617,6 +626,7 @@ case class SortMergeJoinExec( | } | if (shouldStop()) return; |} + |$cleanedFlag = true; |$eagerCleanup """.stripMargin } @@ -665,8 +675,13 @@ case class SortMergeJoinExec( 
conditionCheck: String, hasOutputRow: String, outputRow: String, - eagerCleanup: String): String = { + eagerCleanup: String, + cleanedFlag: String): String = { s""" + |if($cleanedFlag) { + | return; + |} + | |while ($findNextJoinRows) { | $beforeLoop | boolean $hasOutputRow = false; @@ -679,6 +694,7 @@ case class SortMergeJoinExec( | } | if (shouldStop()) return; |} + |$cleanedFlag = true; |$eagerCleanup """.stripMargin } From bd9dd3887eaf8e80a7084774fa3e893f2b91f659 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 4 Aug 2023 14:58:04 +0800 Subject: [PATCH 23/68] [SPARK-43967][SQL][PYTHON] Add memory limits for Python UDTF analyzer ### What changes were proposed in this pull request? Adds memory limits for Python UDTF analyzer. - `spark.sql.analyzer.pythonUDTF.analyzeInPython.memory` (`None` by default) > The amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be limited to this amount. If not set, Spark will not limit Python's memory use and it is up to the application to avoid exceeding the overhead memory space shared with other non-JVM processes. Note: Windows does not support resource limiting and actual resource is not limited on MacOS. ### Why are the changes needed? Python UDTF analyzer should be able to set a memory limit. ### Does this PR introduce _any_ user-facing change? Users will be able to set the memory limit for Python UDTF analyzer. ### How was this patch tested? Existing tests. Closes #42328 from ueshin/issues/SPARK-44648/memory_limits. Authored-by: Takuya UESHIN Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/worker/analyze_udtf.py | 5 ++ python/pyspark/worker.py | 42 ++--------------- python/pyspark/worker_util.py | 47 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 14 ++++++ .../python/UserDefinedPythonFunction.scala | 4 ++ 5 files changed, 73 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index 44dcd8c892c8e..9ffa03541e695 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -40,6 +40,7 @@ pickleSer, send_accumulator_updates, setup_broadcasts, + setup_memory_limits, setup_spark_files, utf8_deserializer, ) @@ -96,6 +97,10 @@ def main(infile: IO, outfile: IO) -> None: """ try: check_python_version(infile) + + memory_limit_mb = int(os.environ.get("PYSPARK_UDTF_ANALYZER_MEMORY_MB", "-1")) + setup_memory_limits(memory_limit_mb) + setup_spark_files(infile) setup_broadcasts(infile) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3acfa58b6fb8b..b32e20e3b0418 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -21,18 +21,11 @@ import os import sys import time -from inspect import currentframe, getframeinfo, getfullargspec +from inspect import getfullargspec import json from typing import Iterable, Iterator -# 'resource' is a Unix specific module. 
-has_resource_module = True -try: - import resource -except ImportError: - has_resource_module = False import traceback -import warnings import faulthandler from pyspark.accumulators import _accumulatorRegistry @@ -70,6 +63,7 @@ pickleSer, send_accumulator_updates, setup_broadcasts, + setup_memory_limits, setup_spark_files, utf8_deserializer, ) @@ -998,38 +992,8 @@ def main(infile, outfile): boundPort = read_int(infile) secret = UTF8Deserializer().loads(infile) - # set up memory limits memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1")) - if memory_limit_mb > 0 and has_resource_module: - total_memory = resource.RLIMIT_AS - try: - (soft_limit, hard_limit) = resource.getrlimit(total_memory) - msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) - print(msg, file=sys.stderr) - - # convert to bytes - new_limit = memory_limit_mb * 1024 * 1024 - - if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: - msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) - print(msg, file=sys.stderr) - resource.setrlimit(total_memory, (new_limit, new_limit)) - - except (resource.error, OSError, ValueError) as e: - # not all systems support resource limits, so warn instead of failing - lineno = ( - getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0 - ) - if "__file__" in globals(): - print( - warnings.formatwarning( - "Failed to set memory limit: {0}".format(e), - ResourceWarning, - __file__, - lineno, - ), - file=sys.stderr, - ) + setup_memory_limits(memory_limit_mb) # initialize global state taskContext = None diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py index eab0daf8f592b..9f6d46c6211d5 100644 --- a/python/pyspark/worker_util.py +++ b/python/pyspark/worker_util.py @@ -19,9 +19,18 @@ Util functions for workers. """ import importlib +from inspect import currentframe, getframeinfo import os import sys from typing import Any, IO +import warnings + +# 'resource' is a Unix specific module. +has_resource_module = True +try: + import resource +except ImportError: + has_resource_module = False from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -71,6 +80,44 @@ def check_python_version(infile: IO) -> None: ) +def setup_memory_limits(memory_limit_mb: int) -> None: + """ + Sets up the memory limits. + + If memory_limit_mb > 0 and `resource` module is available, sets the memory limit. + Windows does not support resource limiting and actual resource is not limited on MacOS. 
+ """ + if memory_limit_mb > 0 and has_resource_module: + total_memory = resource.RLIMIT_AS + try: + (soft_limit, hard_limit) = resource.getrlimit(total_memory) + msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) + print(msg, file=sys.stderr) + + # convert to bytes + new_limit = memory_limit_mb * 1024 * 1024 + + if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: + msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) + print(msg, file=sys.stderr) + resource.setrlimit(total_memory, (new_limit, new_limit)) + + except (resource.error, OSError, ValueError) as e: + # not all systems support resource limits, so warn instead of failing + curent = currentframe() + lineno = getframeinfo(curent).lineno + 1 if curent is not None else 0 + if "__file__" in globals(): + print( + warnings.formatwarning( + "Failed to set memory limit: {0}".format(e), + ResourceWarning, + __file__, + lineno, + ), + file=sys.stderr, + ) + + def setup_spark_files(infile: IO) -> None: """ Set up Spark files, archives, and pyfiles. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dfa2a0f251fea..ad2d323140a6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2944,6 +2944,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PYTHON_TABLE_UDF_ANALYZER_MEMORY = + buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory") + .doc("The amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB " + + "unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be " + + "limited to this amount. 
If not set, Spark will not limit Python's " + + "memory use and it is up to the application to avoid exceeding the overhead memory space " + + "shared with other non-JVM processes.\nNote: Windows does not support resource limiting " + + "and actual resource is not limited on MacOS.") + .version("4.0.0") + .bytesConf(ByteUnit.MiB) + .createOptional + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME = buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName") .internal() @@ -5012,6 +5024,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def pysparkWorkerPythonExecutable: Option[String] = getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE) + def pythonUDTFAnalyzerMemory: Option[Long] = getConf(PYTHON_TABLE_UDF_ANALYZER_MEMORY) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 05239d8d16462..36cb2e17835a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -175,6 +175,7 @@ object UserDefinedPythonTableFunction { val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE) val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + val workerMemoryMb = SQLConf.get.pythonUDTFAnalyzerMemory val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) @@ -192,6 +193,9 @@ object UserDefinedPythonTableFunction { if (simplifiedTraceback) { envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") } + workerMemoryMb.foreach { memoryMb => + envVars.put("PYSPARK_UDTF_ANALYZER_MEMORY_MB", memoryMb.toString) + } envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) From e4bae48d5aa38f98bf9f62724a2ce8111ab2ca5e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 4 Aug 2023 15:35:29 +0800 Subject: [PATCH 24/68] [SPARK-44618][INFRA] Uninstall CodeQL/Go/Node in non-container jobs ### What changes were proposed in this pull request? Uninstall CodeQL/Go/Node in non-container jobs ### Why are the changes needed? it can save 10G disk space before this PR: ![image](https://github.com/apache/spark/assets/7322292/dcd45849-4849-4e95-ae76-f5b7c80b19d3) after this PR: ![image](https://github.com/apache/spark/assets/7322292/042bda6d-43d6-42ea-9f53-abd57766ba99) ### Does this PR introduce _any_ user-facing change? no, infra-only ### How was this patch tested? updated CI Closes #42333 from zhengruifeng/infra_uninstall_codeql. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- dev/free_disk_space | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dev/free_disk_space b/dev/free_disk_space index 87a09a524f4fd..2b2b20f814e02 100755 --- a/dev/free_disk_space +++ b/dev/free_disk_space @@ -34,7 +34,11 @@ sudo rm -rf /usr/local/share/powershell sudo rm -rf /usr/local/share/chromium sudo rm -rf /usr/local/lib/android sudo rm -rf /usr/local/lib/node_modules + sudo rm -rf /opt/az +sudo rm -rf /opt/hostedtoolcache/CodeQL +sudo rm -rf /opt/hostedtoolcache/go +sudo rm -rf /opt/hostedtoolcache/node sudo apt-get remove --purge -y '^aspnet.*' sudo apt-get remove --purge -y '^dotnet-.*' From e019e8720fb5495c990735976ed4b50c3a006804 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 4 Aug 2023 17:18:24 +0800 Subject: [PATCH 25/68] [SPARK-44600][INFRA] Make `repl` module to pass Maven daily testing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? The following situation exists in the Spark code: 1. The `repl` module depends on the `core` module, which in turn depends on the `common/network-common` module. 2. The `common/network-common` module will perform `shade+relocation` operation on the guava dependency (dependency scope is compile), while `core` will only perform `relocation` operation on the Guava dependency (dependency scope is provided) So when we conduct Maven tests on the `repl` module, it is necessary to depend on both the jar package of `core` module and `common/network-common` module or both `target/classes` directory of `core` module and `common/network-common` module, and when the Maven test command leads to the dependence on the jar package of `core` module and `target/classes` directory of `common/network-common` module, test failures will occure due to `shade+relocation` issue. For example: if we run `build/mvn test -pl common/network-common,repl` The test `Dependencies classpath` is `/Users/yangjie01/.m2/repository/org/apache/spark/spark-core_2.12/4.0.0-SNAPSHOT/spark-core_2.12-4.0.0-SNAPSHOT.jar:.../Users/yangjie01/SourceCode/git/spark-mine-12/common/network-common/target/scala-2.12/classes:...` `repl` module test failed as follows: ``` *** RUN ABORTED *** java.lang.NoClassDefFoundError: org/sparkproject/guava/cache/CacheLoader at org.apache.spark.SparkConf.loadFromSystemProperties(SparkConf.scala:75) at org.apache.spark.SparkConf.(SparkConf.scala:70) at org.apache.spark.SparkConf.(SparkConf.scala:59) at org.apache.spark.repl.Main$.(Main.scala:37) at org.apache.spark.repl.Main$.(Main.scala) at org.apache.spark.repl.ReplSuite.$anonfun$new$1(ReplSuite.scala:94) at org.scalatest.enablers.Timed$$anon$1.timeoutAfter(Timed.scala:127) at org.scalatest.concurrent.TimeLimits$.failAfterImpl(TimeLimits.scala:282) at org.scalatest.concurrent.TimeLimits.failAfter(TimeLimits.scala:231) at org.scalatest.concurrent.TimeLimits.failAfter$(TimeLimits.scala:230) ... 
Cause: java.lang.ClassNotFoundException: org.sparkproject.guava.cache.CacheLoader at java.net.URLClassLoader.findClass(URLClassLoader.java:387) at java.lang.ClassLoader.loadClass(ClassLoader.java:419) at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:352) at java.lang.ClassLoader.loadClass(ClassLoader.java:352) at org.apache.spark.SparkConf.loadFromSystemProperties(SparkConf.scala:75) at org.apache.spark.SparkConf.<init>(SparkConf.scala:70) at org.apache.spark.SparkConf.<init>(SparkConf.scala:59) at org.apache.spark.repl.Main$.<init>(Main.scala:37) at org.apache.spark.repl.Main$.<clinit>(Main.scala) at org.apache.spark.repl.ReplSuite.$anonfun$new$1(ReplSuite.scala:94) ``` The test failed because `core.jar` has already relocated the Guava class path, while the content in `network-common/target/scala-2.12/classes` has not yet undergone `shade + relocation` for Guava; this is determined by the lifecycle executed by Maven. But when we execute `build/mvn clean install -pl common/network-common,repl`, the test `Dependencies classpath` is `/Users/yangjie01/.m2/repository/org/apache/spark/spark-core_2.12/4.0.0-SNAPSHOT/spark-core_2.12-4.0.0-SNAPSHOT.jar:.../Users/yangjie01/SourceCode/git/spark-mine-12/common/network-common/target/spark-network-common_2.12-4.0.0-SNAPSHOT.jar:...` and all tests passed. The failure of the `repl` module test in the Maven daily test is due to a similar reason: https://github.com/apache/spark/actions/runs/5751080986/job/15589117861 The possible solutions are as follows: 1. Force the use of `network-common.jar` as a test dependency in the above scenario during Maven testing, but I haven’t found a solution that can be confirmed as viable (consulted gpt-4). 2. Make the `core` module also perform shading on Guava, but this would increase the size of the core.jar by 15+% (14660516 bytes -> 16871122 bytes). 3. Move the `common/network-common` module to a separate group for testing to avoid similar problems, but this would waste some GA resources. 4. Move the `repl` module to another group to avoid this issue, but the remaining modules in the original group need to ensure that they can pass the test. Ultimately, this PR chose method 4, moving the `repl` module and `hive-thriftserver` into the same group, while also verifying through GA that the `repl` module and the original group can both pass the test. ### Why are the changes needed? Make the `repl` module pass Maven daily testing. After this PR, only the `connector/connect/client/jvm` and `connector/connect/server` modules will have Maven test failures. ### Does this PR introduce _any_ user-facing change? No, just for test. ### How was this patch tested? - Verified and passed using GitHub Actions: https://github.com/LuciferYang/spark/actions/runs/5745978299/job/15580477796 Closes #42291 from LuciferYang/maven-repl.
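To make the failure mode above concrete, here is a minimal, illustrative Scala sketch; it is not part of this patch and the object name is made up. Bytecode in the shaded `core` jar references the relocated package `org.sparkproject.guava`, so if the classpath only provides the unshaded Guava classes (as `network-common/target/scala-2.12/classes` does before shading has run), resolving the relocated name fails exactly like the `ClassNotFoundException` in the trace above:

```scala
// Illustrative sketch only: shows why a relocated class name cannot be resolved
// when only the unshaded Guava classes are on the classpath (and vice versa).
object RelocationSketch {
  def main(args: Array[String]): Unit = {
    val unshaded  = "com.google.common.cache.CacheLoader"      // present in plain Guava
    val relocated = "org.sparkproject.guava.cache.CacheLoader" // only exists after shading
    Seq(unshaded, relocated).foreach { name =>
      try {
        Class.forName(name)
        println(s"loaded $name")
      } catch {
        case _: ClassNotFoundException =>
          println(s"cannot load $name: only the other flavour of Guava is on the classpath")
      }
    }
  }
}
```

Run with plain Guava on the classpath, the first name loads and the second is reported missing; run against the shaded artifacts, the situation is reversed, which is why mixing a shaded jar with unshaded `target/classes` breaks the test.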
Authored-by: yangjie01 Signed-off-by: yangjie01 --- .github/workflows/maven_test.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index 48a4d6b5ff990..618ab69ba5998 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -57,11 +57,11 @@ jobs: - hive2.3 modules: - >- - core,repl,launcher,common#unsafe,common#kvstore,common#network-common,common#network-shuffle,common#sketch + core,launcher,common#unsafe,common#kvstore,common#network-common,common#network-shuffle,common#sketch - >- graphx,streaming,mllib-local,mllib,hadoop-cloud - >- - sql#hive-thriftserver + repl,sql#hive-thriftserver - >- connector#kafka-0-10,connector#kafka-0-10-sql,connector#kafka-0-10-token-provider,connector#spark-ganglia-lgpl,connector#protobuf,connector#avro - >- @@ -187,9 +187,9 @@ jobs: ./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pspark-ganglia-lgpl -Djava.version=${JAVA_VERSION/-ea} -Dtest.exclude.tags="$EXCLUDED_TAGS" test -fae elif [[ "$MODULES_TO_TEST" == "connect" ]]; then ./build/mvn $MAVEN_CLI_OPTS -Djava.version=${JAVA_VERSION/-ea} -pl connector/connect/client/jvm,connector/connect/common,connector/connect/server test -fae - elif [[ "$MODULES_TO_TEST" == "sql#hive-thriftserver" ]]; then + elif [[ "$MODULES_TO_TEST" == *"sql#hive-thriftserver"* ]]; then # To avoid a compilation loop, for the `sql/hive-thriftserver` module, run `clean install` instead - ./build/mvn $MAVEN_CLI_OPTS -pl sql/hive-thriftserver -Phive -Phive-thriftserver -Djava.version=${JAVA_VERSION/-ea} clean install -fae + ./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pspark-ganglia-lgpl -Djava.version=${JAVA_VERSION/-ea} clean install -fae else ./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Pspark-ganglia-lgpl -Phadoop-cloud -Djava.version=${JAVA_VERSION/-ea} test -fae fi From a6f048a3d99305ac69755609101b3e7128eabdfe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 Aug 2023 19:23:48 +0800 Subject: [PATCH 26/68] [SPARK-44655][SQL] Make the code cleaner about static and dynamic data/partition filters ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/41088 to make the code cleaner. There are two kinds of data/partition filters: static filters that can be executed during the planning phase, and dynamic filters that can only be executed after the planning phase. This PR makes sure these two kinds of filters are properly used and adds code comments to explain it. ### Why are the changes needed? code cleanup ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing tests Closes #42318 from cloud-fan/minor. 
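As a reading aid for the `DataSourceScanExec` diff that follows, here is a minimal sketch of the split the commit message describes. It is not the actual Spark code; it only borrows the `PlanExpression` check used in the diff, assumes `spark-catalyst` is on the classpath, and the object name is made up:

```scala
// Simplified sketch of the static-vs-dynamic filter split described above.
import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}

object FilterSplitSketch {
  // A filter is "dynamic" if it embeds a plan (e.g. a subquery or a
  // dynamic-pruning expression) whose result only exists at execution time.
  private def isDynamic(e: Expression): Boolean =
    e.exists(_.isInstanceOf[PlanExpression[_]])

  // Returns (staticFilters, dynamicFilters).
  def split(filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
    val (dynamic, static) = filters.partition(isDynamic)
    (static, dynamic)
  }
}
```

Static filters can safely drive planning-time work such as listing files and deriving `outputPartitioning`, while dynamic filters must wait until execution; that is exactly the distinction the diff encodes in `selectedPartitions` versus `dynamicallySelectedPartitions`.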
Authored-by: Wenchen Fan Signed-off-by: Kent Yao --- .../sql/execution/DataSourceScanExec.scala | 68 +++---- .../sql/execution/datasources/FileIndex.scala | 55 +++++ .../datasources/FileSourceStrategy.scala | 6 +- .../PartitioningAwareFileIndex.scala | 50 +---- .../sql-tests/results/explain-aqe.sql.out | 4 +- .../sql-tests/results/explain.sql.out | 4 +- .../q14b.sf100/explain.txt | 4 +- .../approved-plans-v1_4/q14b/explain.txt | 4 +- .../approved-plans-v1_4/q54.sf100/explain.txt | 12 +- .../approved-plans-v1_4/q54/explain.txt | 12 +- .../approved-plans-v1_4/q58.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q58/explain.txt | 2 +- .../approved-plans-v1_4/q6.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q6/explain.txt | 2 +- .../approved-plans-v1_4/q64/explain.txt | 192 +++++++++--------- .../approved-plans-v2_7/q14.sf100/explain.txt | 4 +- .../approved-plans-v2_7/q14/explain.txt | 4 +- .../approved-plans-v2_7/q6.sf100/explain.txt | 2 +- .../approved-plans-v2_7/q6/explain.txt | 2 +- 19 files changed, 221 insertions(+), 210 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index a739fa40c71cb..e5a38967dc3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -249,12 +249,18 @@ trait FileSourceScanLike extends DataSourceScanExec { private def isDynamicPruningFilter(e: Expression): Boolean = e.exists(_.isInstanceOf[PlanExpression[_]]) + + // This field will be accessed during planning (e.g., `outputPartitioning` relies on it), and can + // only use static filters. @transient lazy val selectedPartitions: Array[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) val startTime = System.nanoTime() - val ret = - relation.location.listFiles( - partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + // The filters may contain subquery expressions which can't be evaluated during planning. + // Here we filter out subquery expressions and get the static data/partition filters, so that + // they can be used to do pruning at the planning phase. + val staticDataFilters = dataFilters.filterNot(isDynamicPruningFilter) + val staticPartitionFilters = partitionFilters.filterNot(isDynamicPruningFilter) + val ret = relation.location.listFiles(staticPartitionFilters, staticDataFilters) setFilesNumAndSizeMetric(ret, true) val timeTakenMs = NANOSECONDS.toMillis( (System.nanoTime() - startTime) + optimizerMetadataTimeNs) @@ -266,6 +272,7 @@ trait FileSourceScanLike extends DataSourceScanExec { // present. This is because such a filter relies on information that is only available at run // time (for instance the keys used in the other side of a join). 
@transient protected lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicDataFilters = dataFilters.filter(isDynamicPruningFilter) val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) if (dynamicPartitionFilters.nonEmpty) { @@ -278,7 +285,11 @@ trait FileSourceScanLike extends DataSourceScanExec { val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) }, Nil) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + var ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + if (dynamicDataFilters.nonEmpty) { + val filePruningRunner = new FilePruningRunner(dynamicDataFilters) + ret = ret.map(filePruningRunner.prune) + } setFilesNumAndSizeMetric(ret, false) val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 driverMetrics("pruningTime").set(timeTakenMs) @@ -288,14 +299,6 @@ trait FileSourceScanLike extends DataSourceScanExec { } } - /** - * [[partitionFilters]] can contain subqueries whose results are available only at runtime so - * accessing [[selectedPartitions]] should be guarded by this method during planning - */ - private def hasPartitionsAvailableAtRunTime: Boolean = { - partitionFilters.exists(ExecSubqueryExpression.hasSubquery) - } - private def toAttribute(colName: String): Option[Attribute] = output.find(_.name == colName) @@ -339,8 +342,7 @@ trait FileSourceScanLike extends DataSourceScanExec { spec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get) val shouldCalculateSortOrder = conf.getConf(SQLConf.LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING) && - sortColumns.nonEmpty && - !hasPartitionsAvailableAtRunTime + sortColumns.nonEmpty val sortOrder = if (shouldCalculateSortOrder) { // In case of bucketing, its possible to have multiple files belonging to the @@ -371,35 +373,29 @@ trait FileSourceScanLike extends DataSourceScanExec { } } - private def translatePushedDownFilters(dataFilters: Seq[Expression]): Seq[Filter] = { + private def translateToV1Filters( + dataFilters: Seq[Expression], + scalarSubqueryToLiteral: execution.ScalarSubquery => Literal): Seq[Filter] = { + val scalarSubqueryReplaced = dataFilters.map(_.transform { + // Replace scalar subquery to literal so that `DataSourceStrategy.translateFilter` can + // support translating it. + case scalarSubquery: execution.ScalarSubquery => scalarSubqueryToLiteral(scalarSubquery) + }) + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) // `dataFilters` should not include any constant metadata col filters // because the metadata struct has been flatted in FileSourceStrategy // and thus metadata col filters are invalid to be pushed down. Metadata that is generated // during the scan can be used for filters. - dataFilters.filterNot(_.references.exists { + scalarSubqueryReplaced.filterNot(_.references.exists { case FileSourceConstantMetadataAttribute(_) => true case _ => false }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) } + // This field may execute subquery expressions and should not be accessed during planning. 
@transient - protected lazy val pushedDownFilters: Seq[Filter] = translatePushedDownFilters(dataFilters) - - @transient - protected lazy val dynamicallyPushedDownFilters: Seq[Filter] = { - if (dataFilters.exists(_.exists(_.isInstanceOf[execution.ScalarSubquery]))) { - // Replace scalar subquery to literal so that `DataSourceStrategy.translateFilter` can - // support translate it. The subquery must has been materialized since SparkPlan always - // execute subquery first. - val normalized = dataFilters.map(_.transform { - case scalarSubquery: execution.ScalarSubquery => scalarSubquery.toLiteral - }) - translatePushedDownFilters(normalized) - } else { - pushedDownFilters - } - } + protected lazy val pushedDownFilters: Seq[Filter] = translateToV1Filters(dataFilters, _.toLiteral) override lazy val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") @@ -407,13 +403,17 @@ trait FileSourceScanLike extends DataSourceScanExec { val locationDesc = location.getClass.getSimpleName + Utils.buildLocationMetadata(location.rootPaths, maxMetadataValueLength) + // `metadata` is accessed during planning and the scalar subquery is not executed yet. Here + // we get the pretty string of the scalar subquery, for display purpose only. + val pushedFiltersForDisplay = translateToV1Filters( + dataFilters, s => Literal("ScalarSubquery#" + s.exprId.id)) val metadata = Map( "Format" -> relation.fileFormat.toString, "ReadSchema" -> requiredSchema.catalogString, "Batched" -> supportsColumnar.toString, "PartitionFilters" -> seqToString(partitionFilters), - "PushedFilters" -> seqToString(pushedDownFilters), + "PushedFilters" -> seqToString(pushedFiltersForDisplay), "DataFilters" -> seqToString(dataFilters), "Location" -> locationDesc) @@ -561,7 +561,7 @@ case class FileSourceScanExec( dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, requiredSchema = requiredSchema, - filters = dynamicallyPushedDownFilters, + filters = pushedDownFilters, options = options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 1b28294e94a88..2535440add19a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution.datasources +import scala.collection.mutable + import org.apache.hadoop.fs._ +import org.apache.spark.paths.SparkPath import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StructType @@ -43,6 +46,58 @@ case class FileStatusWithMetadata(fileStatus: FileStatus, metadata: Map[String, */ case class PartitionDirectory(values: InternalRow, files: Seq[FileStatusWithMetadata]) +/** + * A runner that extracts file metadata filters from the given `filters` and use it to prune files + * in `PartitionDirectory`. + */ +class FilePruningRunner(filters: Seq[Expression]) { + // retrieve the file constant metadata filters and reduce to a final filter expression that can + // be applied to files. 
+ val fileMetadataFilterOpt = filters.filter { f => + f.references.nonEmpty && f.references.forall { + case FileSourceConstantMetadataAttribute(metadataAttr) => + // we only know block start and length after splitting files, so skip it here + metadataAttr.name != FileFormat.FILE_BLOCK_START && + metadataAttr.name != FileFormat.FILE_BLOCK_LENGTH + case _ => false + } + }.reduceOption(And) + + // - Retrieve all required metadata attributes and put them into a sequence + // - Bind all file constant metadata attribute references to their respective index + val requiredMetadataColumnNames: mutable.Buffer[String] = mutable.Buffer.empty + val boundedFilterMetadataStructOpt = fileMetadataFilterOpt.map { fileMetadataFilter => + Predicate.createInterpreted(fileMetadataFilter.transform { + case attr: AttributeReference => + val existingMetadataColumnIndex = requiredMetadataColumnNames.indexOf(attr.name) + val metadataColumnIndex = if (existingMetadataColumnIndex >= 0) { + existingMetadataColumnIndex + } else { + requiredMetadataColumnNames += attr.name + requiredMetadataColumnNames.length - 1 + } + BoundReference(metadataColumnIndex, attr.dataType, nullable = true) + }) + } + + private def matchFileMetadataPredicate(partitionValues: InternalRow, f: FileStatus): Boolean = { + // use option.forall, so if there is no filter no metadata struct, return true + boundedFilterMetadataStructOpt.forall { boundedFilter => + val row = + FileFormat.createMetadataInternalRow(partitionValues, requiredMetadataColumnNames.toSeq, + SparkPath.fromFileStatus(f), f.getLen, f.getModificationTime) + boundedFilter.eval(row) + } + } + + def prune(pd: PartitionDirectory): PartitionDirectory = { + val prunedFiles = pd.files.filter { f => + matchFileMetadataPredicate(InternalRow.empty, f.fileStatus) + } + pd.copy(files = prunedFiles) + } +} + object PartitionDirectory { // For backward compat with code that does not know about extra file metadata def apply(values: InternalRow, files: Array[FileStatus]): PartitionDirectory = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 5673e12927c70..551fe253657c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -174,11 +174,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec val bucketSet = if (shouldPruneBuckets(bucketSpec)) { - // subquery expressions are filtered out because they can't be used to prune buckets - // as data filters, yet they would be executed - val normalizedFiltersWithoutSubqueries = - normalizedFilters.filterNot(SubqueryExpression.hasSubquery) - genBucketSet(normalizedFiltersWithoutSubqueries, bucketSpec.get) + genBucketSet(normalizedFilters, bucketSpec.get) } else { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index b25162aad9a77..ef4fff2360097 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -28,7 +28,6 @@ import 
org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.datasources.FileFormat.createMetadataInternalRow import org.apache.spark.sql.types.StructType /** @@ -73,48 +72,10 @@ abstract class PartitioningAwareFileIndex( isDataPath(f.getPath) && f.getLen > 0 } - // retrieve the file constant metadata filters and reduce to a final filter expression that can - // be applied to files. - val fileMetadataFilterOpt = dataFilters.filter { f => - f.references.nonEmpty && f.references.forall { - case FileSourceConstantMetadataAttribute(metadataAttr) => - // we only know block start and length after splitting files, so skip it here - metadataAttr.name != FileFormat.FILE_BLOCK_START && - metadataAttr.name != FileFormat.FILE_BLOCK_LENGTH - case _ => false - } - }.reduceOption(expressions.And) - - // - Retrieve all required metadata attributes and put them into a sequence - // - Bind all file constant metadata attribute references to their respective index - val requiredMetadataColumnNames: mutable.Buffer[String] = mutable.Buffer.empty - val boundedFilterMetadataStructOpt = fileMetadataFilterOpt.map { fileMetadataFilter => - Predicate.createInterpreted(fileMetadataFilter.transform { - case attr: AttributeReference => - val existingMetadataColumnIndex = requiredMetadataColumnNames.indexOf(attr.name) - val metadataColumnIndex = if (existingMetadataColumnIndex >= 0) { - existingMetadataColumnIndex - } else { - requiredMetadataColumnNames += attr.name - requiredMetadataColumnNames.length - 1 - } - BoundReference(metadataColumnIndex, attr.dataType, nullable = true) - }) - } - - def matchFileMetadataPredicate(partitionValues: InternalRow, f: FileStatus): Boolean = { - // use option.forall, so if there is no filter no metadata struct, return true - boundedFilterMetadataStructOpt.forall { boundedFilter => - val row = - createMetadataInternalRow(partitionValues, requiredMetadataColumnNames.toSeq, - SparkPath.fromFileStatus(f), f.getLen, f.getModificationTime) - boundedFilter.eval(row) - } - } - + val filePruningRunner = new FilePruningRunner(dataFilters) val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { - PartitionDirectory(InternalRow.empty, allFiles().toArray - .filter(f => isNonEmptyFile(f) && matchFileMetadataPredicate(InternalRow.empty, f))) :: Nil + filePruningRunner.prune( + PartitionDirectory(InternalRow.empty, allFiles().toArray.filter(isNonEmptyFile))) :: Nil } else { if (recursiveFileLookup) { throw new IllegalArgumentException( @@ -125,14 +86,13 @@ abstract class PartitioningAwareFileIndex( val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => // Directory has children files in it, return them - existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f) && - matchFileMetadataPredicate(values, f)) + existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f)) case None => // Directory does not exist, or has no children files Nil } - PartitionDirectory(values, files.toArray) + filePruningRunner.prune(PartitionDirectory(values, files.toArray)) } } logTrace("Selected files after partition pruning:\n\t" + selectedPartitions.mkString("\n\t")) diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 44b2679f89d86..7dfaaea46b75d 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -398,7 +398,7 @@ AdaptiveSparkPlan (3) Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp1] -PushedFilters: [IsNotNull(key), IsNotNull(val), GreaterThan(val,3)] +PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(key,ScalarSubquery#x), GreaterThan(val,3)] ReadSchema: struct (2) Filter @@ -425,7 +425,7 @@ AdaptiveSparkPlan (10) Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp2] -PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(val,2)] +PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(key,ScalarSubquery#x), EqualTo(val,2)] ReadSchema: struct (5) Filter diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 0cd94abc9b307..ef4d57735aa39 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -405,7 +405,7 @@ struct Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp1] -PushedFilters: [IsNotNull(key), IsNotNull(val), GreaterThan(val,3)] +PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(key,ScalarSubquery#x), GreaterThan(val,3)] ReadSchema: struct (2) ColumnarToRow [codegen id : 1] @@ -433,7 +433,7 @@ Subquery:2 Hosting operator id = 1 Hosting Expression = Subquery scalar-subquery Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp2] -PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(val,2)] +PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(key,ScalarSubquery#x), EqualTo(val,2)] ReadSchema: struct (5) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt index 16bdfb1041619..0986e92088caa 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt @@ -648,7 +648,7 @@ BroadcastExchange (114) Output [2]: [d_date_sk#36, d_week_seq#100] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#101), IsNotNull(d_date_sk)] ReadSchema: struct (111) ColumnarToRow [codegen id : 1] @@ -741,7 +741,7 @@ BroadcastExchange (128) Output [2]: [d_date_sk#60, d_week_seq#108] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#109), IsNotNull(d_date_sk)] ReadSchema: struct (125) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt index cc8b88f3adcbf..3f4f3653371d9 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt +++ 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt @@ -618,7 +618,7 @@ BroadcastExchange (108) Output [2]: [d_date_sk#40, d_week_seq#100] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#101), IsNotNull(d_date_sk)] ReadSchema: struct (105) ColumnarToRow [codegen id : 1] @@ -711,7 +711,7 @@ BroadcastExchange (122) Output [2]: [d_date_sk#64, d_week_seq#108] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#109), IsNotNull(d_date_sk)] ReadSchema: struct (119) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt index 19643cccab639..572452c72529e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt @@ -387,7 +387,7 @@ BroadcastExchange (69) Output [2]: [d_date_sk#29, d_month_seq#41] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,ScalarSubquery#42), LessThanOrEqual(d_month_seq,ScalarSubquery#43), IsNotNull(d_date_sk)] ReadSchema: struct (66) ColumnarToRow [codegen id : 1] @@ -395,7 +395,7 @@ Input [2]: [d_date_sk#29, d_month_seq#41] (67) Filter [codegen id : 1] Input [2]: [d_date_sk#29, d_month_seq#41] -Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= ReusedSubquery Subquery scalar-subquery#42, [id=#43])) AND (d_month_seq#41 <= ReusedSubquery Subquery scalar-subquery#44, [id=#45])) AND isnotnull(d_date_sk#29)) +Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= ReusedSubquery Subquery scalar-subquery#42, [id=#44])) AND (d_month_seq#41 <= ReusedSubquery Subquery scalar-subquery#43, [id=#45])) AND isnotnull(d_date_sk#29)) (68) Project [codegen id : 1] Output [1]: [d_date_sk#29] @@ -405,11 +405,11 @@ Input [2]: [d_date_sk#29, d_month_seq#41] Input [1]: [d_date_sk#29] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [plan_id=9] -Subquery:4 Hosting operator id = 67 Hosting Expression = ReusedSubquery Subquery scalar-subquery#42, [id=#43] +Subquery:4 Hosting operator id = 67 Hosting Expression = ReusedSubquery Subquery scalar-subquery#42, [id=#44] -Subquery:5 Hosting operator id = 67 Hosting Expression = ReusedSubquery Subquery scalar-subquery#44, [id=#45] +Subquery:5 Hosting operator id = 67 Hosting Expression = ReusedSubquery Subquery scalar-subquery#43, [id=#45] -Subquery:6 Hosting operator id = 65 Hosting Expression = Subquery scalar-subquery#42, [id=#43] +Subquery:6 Hosting operator id = 65 Hosting Expression = Subquery scalar-subquery#42, [id=#44] * HashAggregate (76) +- Exchange (75) +- * HashAggregate (74) @@ -455,7 +455,7 @@ Functions: [] Aggregate Attributes: [] Results [1]: [(d_month_seq + 1)#49] -Subquery:7 Hosting operator id = 65 Hosting Expression = Subquery scalar-subquery#44, [id=#45] +Subquery:7 Hosting operator id = 65 Hosting Expression = Subquery scalar-subquery#43, [id=#45] 
* HashAggregate (83) +- Exchange (82) +- * HashAggregate (81) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt index cefaff0c09d39..502d4f3ee6ab3 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt @@ -372,7 +372,7 @@ BroadcastExchange (66) Output [2]: [d_date_sk#29, d_month_seq#41] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,ScalarSubquery#42), LessThanOrEqual(d_month_seq,ScalarSubquery#43), IsNotNull(d_date_sk)] ReadSchema: struct (63) ColumnarToRow [codegen id : 1] @@ -380,7 +380,7 @@ Input [2]: [d_date_sk#29, d_month_seq#41] (64) Filter [codegen id : 1] Input [2]: [d_date_sk#29, d_month_seq#41] -Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= ReusedSubquery Subquery scalar-subquery#42, [id=#43])) AND (d_month_seq#41 <= ReusedSubquery Subquery scalar-subquery#44, [id=#45])) AND isnotnull(d_date_sk#29)) +Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= ReusedSubquery Subquery scalar-subquery#42, [id=#44])) AND (d_month_seq#41 <= ReusedSubquery Subquery scalar-subquery#43, [id=#45])) AND isnotnull(d_date_sk#29)) (65) Project [codegen id : 1] Output [1]: [d_date_sk#29] @@ -390,11 +390,11 @@ Input [2]: [d_date_sk#29, d_month_seq#41] Input [1]: [d_date_sk#29] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [plan_id=10] -Subquery:4 Hosting operator id = 64 Hosting Expression = ReusedSubquery Subquery scalar-subquery#42, [id=#43] +Subquery:4 Hosting operator id = 64 Hosting Expression = ReusedSubquery Subquery scalar-subquery#42, [id=#44] -Subquery:5 Hosting operator id = 64 Hosting Expression = ReusedSubquery Subquery scalar-subquery#44, [id=#45] +Subquery:5 Hosting operator id = 64 Hosting Expression = ReusedSubquery Subquery scalar-subquery#43, [id=#45] -Subquery:6 Hosting operator id = 62 Hosting Expression = Subquery scalar-subquery#42, [id=#43] +Subquery:6 Hosting operator id = 62 Hosting Expression = Subquery scalar-subquery#42, [id=#44] * HashAggregate (73) +- Exchange (72) +- * HashAggregate (71) @@ -440,7 +440,7 @@ Functions: [] Aggregate Attributes: [] Results [1]: [(d_month_seq + 1)#49] -Subquery:7 Hosting operator id = 62 Hosting Expression = Subquery scalar-subquery#44, [id=#45] +Subquery:7 Hosting operator id = 62 Hosting Expression = Subquery scalar-subquery#43, [id=#45] * HashAggregate (80) +- Exchange (79) +- * HashAggregate (78) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt index 26ffe2e0b323e..d9083741a88e7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt @@ -320,7 +320,7 @@ Condition : isnotnull(d_date_sk#5) Output [2]: [d_date#40, d_week_seq#41] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#42)] ReadSchema: struct (54) 
ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt index cdb5e45f66872..7f95e52cb8df5 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt @@ -320,7 +320,7 @@ Condition : isnotnull(d_date_sk#7) Output [2]: [d_date#40, d_week_seq#41] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#42)] ReadSchema: struct (54) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6.sf100/explain.txt index 93db1e57839df..ac69497fb26ca 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6.sf100/explain.txt @@ -272,7 +272,7 @@ BroadcastExchange (50) Output [2]: [d_date_sk#16, d_month_seq#26] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), EqualTo(d_month_seq,ScalarSubquery#27), IsNotNull(d_date_sk)] ReadSchema: struct (47) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6/explain.txt index bd5bdfb666100..75644fea091fe 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6/explain.txt @@ -242,7 +242,7 @@ BroadcastExchange (44) Output [2]: [d_date_sk#9, d_month_seq#26] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), EqualTo(d_month_seq,ScalarSubquery#27), IsNotNull(d_date_sk)] ReadSchema: struct (41) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt index 620bab62bf16d..69023c88202af 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt @@ -760,15 +760,15 @@ Input [4]: [cs_item_sk#122, sum#123, sum#124, isEmpty#125] Keys [1]: [cs_item_sk#122] Functions [2]: [sum(UnscaledValue(cs_ext_list_price#126)), sum(((cr_refunded_cash#127 + cr_reversed_charge#128) + cr_store_credit#129))] Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_list_price#126))#33, sum(((cr_refunded_cash#127 + cr_reversed_charge#128) + cr_store_credit#129))#34] -Results [3]: [cs_item_sk#122, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#126))#33,17,2) AS sale#35, sum(((cr_refunded_cash#127 + cr_reversed_charge#128) + cr_store_credit#129))#34 AS refund#36] +Results [3]: [cs_item_sk#122, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#126))#33,17,2) AS sale#130, sum(((cr_refunded_cash#127 + 
cr_reversed_charge#128) + cr_store_credit#129))#34 AS refund#131] (126) Filter [codegen id : 35] -Input [3]: [cs_item_sk#122, sale#35, refund#36] -Condition : ((isnotnull(sale#35) AND isnotnull(refund#36)) AND (cast(sale#35 as decimal(21,2)) > (2 * refund#36))) +Input [3]: [cs_item_sk#122, sale#130, refund#131] +Condition : ((isnotnull(sale#130) AND isnotnull(refund#131)) AND (cast(sale#130 as decimal(21,2)) > (2 * refund#131))) (127) Project [codegen id : 35] Output [1]: [cs_item_sk#122] -Input [3]: [cs_item_sk#122, sale#35, refund#36] +Input [3]: [cs_item_sk#122, sale#130, refund#131] (128) Sort [codegen id : 35] Input [1]: [cs_item_sk#122] @@ -785,239 +785,239 @@ Output [11]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#1 Input [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117, cs_item_sk#122] (131) ReusedExchange [Reuses operator id: 191] -Output [2]: [d_date_sk#130, d_year#131] +Output [2]: [d_date_sk#132, d_year#133] (132) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_sold_date_sk#117] -Right keys [1]: [d_date_sk#130] +Right keys [1]: [d_date_sk#132] Join type: Inner Join condition: None (133) Project [codegen id : 51] -Output [11]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131] -Input [13]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117, d_date_sk#130, d_year#131] +Output [11]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133] +Input [13]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117, d_date_sk#132, d_year#133] (134) ReusedExchange [Reuses operator id: 41] -Output [3]: [s_store_sk#132, s_store_name#133, s_zip#134] +Output [3]: [s_store_sk#134, s_store_name#135, s_zip#136] (135) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_store_sk#111] -Right keys [1]: [s_store_sk#132] +Right keys [1]: [s_store_sk#134] Join type: Inner Join condition: None (136) Project [codegen id : 51] -Output [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134] -Input [14]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_sk#132, s_store_name#133, s_zip#134] +Output [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136] +Input [14]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_sk#134, s_store_name#135, s_zip#136] 
(137) ReusedExchange [Reuses operator id: 47] -Output [6]: [c_customer_sk#135, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, c_first_sales_date_sk#140] +Output [6]: [c_customer_sk#137, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, c_first_sales_date_sk#142] (138) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_customer_sk#107] -Right keys [1]: [c_customer_sk#135] +Right keys [1]: [c_customer_sk#137] Join type: Inner Join condition: None (139) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, c_first_sales_date_sk#140] -Input [18]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_customer_sk#135, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, c_first_sales_date_sk#140] +Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, c_first_sales_date_sk#142] +Input [18]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_customer_sk#137, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, c_first_sales_date_sk#142] (140) ReusedExchange [Reuses operator id: 53] -Output [2]: [d_date_sk#141, d_year#142] +Output [2]: [d_date_sk#143, d_year#144] (141) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_first_sales_date_sk#140] -Right keys [1]: [d_date_sk#141] +Left keys [1]: [c_first_sales_date_sk#142] +Right keys [1]: [d_date_sk#143] Join type: Inner Join condition: None (142) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, d_year#142] -Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, c_first_sales_date_sk#140, d_date_sk#141, d_year#142] +Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, d_year#144] +Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, 
c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, c_first_sales_date_sk#142, d_date_sk#143, d_year#144] (143) ReusedExchange [Reuses operator id: 53] -Output [2]: [d_date_sk#143, d_year#144] +Output [2]: [d_date_sk#145, d_year#146] (144) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_first_shipto_date_sk#139] -Right keys [1]: [d_date_sk#143] +Left keys [1]: [c_first_shipto_date_sk#141] +Right keys [1]: [d_date_sk#145] Join type: Inner Join condition: None (145) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144] -Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, d_year#142, d_date_sk#143, d_year#144] +Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146] +Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, d_year#144, d_date_sk#145, d_year#146] (146) ReusedExchange [Reuses operator id: 62] -Output [2]: [cd_demo_sk#145, cd_marital_status#146] +Output [2]: [cd_demo_sk#147, cd_marital_status#148] (147) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_cdemo_sk#108] -Right keys [1]: [cd_demo_sk#145] +Right keys [1]: [cd_demo_sk#147] Join type: Inner Join condition: None (148) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, cd_marital_status#146] -Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, cd_demo_sk#145, cd_marital_status#146] +Output [16]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, cd_marital_status#148] +Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, cd_demo_sk#147, cd_marital_status#148] (149) ReusedExchange [Reuses operator id: 62] -Output [2]: [cd_demo_sk#147, cd_marital_status#148] +Output 
[2]: [cd_demo_sk#149, cd_marital_status#150] (150) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_current_cdemo_sk#136] -Right keys [1]: [cd_demo_sk#147] +Left keys [1]: [c_current_cdemo_sk#138] +Right keys [1]: [cd_demo_sk#149] Join type: Inner -Join condition: NOT (cd_marital_status#146 = cd_marital_status#148) +Join condition: NOT (cd_marital_status#148 = cd_marital_status#150) (151) Project [codegen id : 51] -Output [14]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144] -Input [18]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, cd_marital_status#146, cd_demo_sk#147, cd_marital_status#148] +Output [14]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146] +Input [18]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, cd_marital_status#148, cd_demo_sk#149, cd_marital_status#150] (152) ReusedExchange [Reuses operator id: 71] -Output [1]: [p_promo_sk#149] +Output [1]: [p_promo_sk#151] (153) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_promo_sk#112] -Right keys [1]: [p_promo_sk#149] +Right keys [1]: [p_promo_sk#151] Join type: Inner Join condition: None (154) Project [codegen id : 51] -Output [13]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144] -Input [15]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, p_promo_sk#149] +Output [13]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146] +Input [15]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, p_promo_sk#151] (155) ReusedExchange [Reuses operator id: 77] -Output [2]: [hd_demo_sk#150, hd_income_band_sk#151] +Output [2]: [hd_demo_sk#152, hd_income_band_sk#153] (156) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_hdemo_sk#109] -Right keys [1]: [hd_demo_sk#150] +Right keys [1]: [hd_demo_sk#152] Join type: Inner Join condition: None (157) Project [codegen id : 51] -Output [13]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151] 
-Input [15]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, hd_demo_sk#150, hd_income_band_sk#151] +Output [13]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153] +Input [15]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, hd_demo_sk#152, hd_income_band_sk#153] (158) ReusedExchange [Reuses operator id: 77] -Output [2]: [hd_demo_sk#152, hd_income_band_sk#153] +Output [2]: [hd_demo_sk#154, hd_income_band_sk#155] (159) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_current_hdemo_sk#137] -Right keys [1]: [hd_demo_sk#152] +Left keys [1]: [c_current_hdemo_sk#139] +Right keys [1]: [hd_demo_sk#154] Join type: Inner Join condition: None (160) Project [codegen id : 51] -Output [13]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153] -Input [15]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_demo_sk#152, hd_income_band_sk#153] +Output [13]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155] +Input [15]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_demo_sk#154, hd_income_band_sk#155] (161) ReusedExchange [Reuses operator id: 86] -Output [5]: [ca_address_sk#154, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158] +Output [5]: [ca_address_sk#156, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160] (162) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_addr_sk#110] -Right keys [1]: [ca_address_sk#154] +Right keys [1]: [ca_address_sk#156] Join type: Inner Join condition: None (163) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158] -Input [18]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_address_sk#154, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158] +Output [16]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, 
hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160] +Input [18]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_address_sk#156, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160] (164) ReusedExchange [Reuses operator id: 86] -Output [5]: [ca_address_sk#159, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] +Output [5]: [ca_address_sk#161, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] (165) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_current_addr_sk#138] -Right keys [1]: [ca_address_sk#159] +Left keys [1]: [c_current_addr_sk#140] +Right keys [1]: [ca_address_sk#161] Join type: Inner Join condition: None (166) Project [codegen id : 51] -Output [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] -Input [21]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_address_sk#159, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] +Output [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] +Input [21]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_address_sk#161, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] (167) ReusedExchange [Reuses operator id: 95] -Output [1]: [ib_income_band_sk#164] +Output [1]: [ib_income_band_sk#166] (168) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [hd_income_band_sk#151] -Right keys [1]: [ib_income_band_sk#164] +Left keys [1]: [hd_income_band_sk#153] +Right keys [1]: [ib_income_band_sk#166] Join type: Inner Join condition: None (169) Project [codegen id : 51] -Output [18]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] -Input [20]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, ib_income_band_sk#164] +Output [18]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, 
s_zip#136, d_year#144, d_year#146, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] +Input [20]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, ib_income_band_sk#166] (170) ReusedExchange [Reuses operator id: 95] -Output [1]: [ib_income_band_sk#165] +Output [1]: [ib_income_band_sk#167] (171) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [hd_income_band_sk#153] -Right keys [1]: [ib_income_band_sk#165] +Left keys [1]: [hd_income_band_sk#155] +Right keys [1]: [ib_income_band_sk#167] Join type: Inner Join condition: None (172) Project [codegen id : 51] -Output [17]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] -Input [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, ib_income_band_sk#165] +Output [17]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] +Input [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, ib_income_band_sk#167] (173) ReusedExchange [Reuses operator id: 105] -Output [2]: [i_item_sk#166, i_product_name#167] +Output [2]: [i_item_sk#168, i_product_name#169] (174) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_item_sk#106] -Right keys [1]: [i_item_sk#166] +Right keys [1]: [i_item_sk#168] Join type: Inner Join condition: None (175) Project [codegen id : 51] -Output [18]: [ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, d_year#142, d_year#144, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, i_item_sk#166, i_product_name#167] -Input [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, i_item_sk#166, i_product_name#167] +Output [18]: [ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, d_year#144, d_year#146, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, i_item_sk#168, i_product_name#169] +Input [19]: [ss_item_sk#106, ss_wholesale_cost#114, 
ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, i_item_sk#168, i_product_name#169] (176) HashAggregate [codegen id : 51] -Input [18]: [ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, d_year#142, d_year#144, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, i_item_sk#166, i_product_name#167] -Keys [15]: [i_product_name#167, i_item_sk#166, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, d_year#131, d_year#142, d_year#144] +Input [18]: [ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, d_year#144, d_year#146, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, i_item_sk#168, i_product_name#169] +Keys [15]: [i_product_name#169, i_item_sk#168, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, d_year#133, d_year#144, d_year#146] Functions [4]: [partial_count(1), partial_sum(UnscaledValue(ss_wholesale_cost#114)), partial_sum(UnscaledValue(ss_list_price#115)), partial_sum(UnscaledValue(ss_coupon_amt#116))] -Aggregate Attributes [4]: [count#77, sum#168, sum#169, sum#170] -Results [19]: [i_product_name#167, i_item_sk#166, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, d_year#131, d_year#142, d_year#144, count#81, sum#171, sum#172, sum#173] +Aggregate Attributes [4]: [count#77, sum#170, sum#171, sum#172] +Results [19]: [i_product_name#169, i_item_sk#168, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, d_year#133, d_year#144, d_year#146, count#81, sum#173, sum#174, sum#175] (177) HashAggregate [codegen id : 51] -Input [19]: [i_product_name#167, i_item_sk#166, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, d_year#131, d_year#142, d_year#144, count#81, sum#171, sum#172, sum#173] -Keys [15]: [i_product_name#167, i_item_sk#166, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, d_year#131, d_year#142, d_year#144] +Input [19]: [i_product_name#169, i_item_sk#168, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, d_year#133, d_year#144, d_year#146, count#81, sum#173, sum#174, sum#175] +Keys [15]: [i_product_name#169, i_item_sk#168, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, d_year#133, d_year#144, d_year#146] Functions [4]: [count(1), sum(UnscaledValue(ss_wholesale_cost#114)), sum(UnscaledValue(ss_list_price#115)), sum(UnscaledValue(ss_coupon_amt#116))] 
Aggregate Attributes [4]: [count(1)#85, sum(UnscaledValue(ss_wholesale_cost#114))#86, sum(UnscaledValue(ss_list_price#115))#87, sum(UnscaledValue(ss_coupon_amt#116))#88] -Results [8]: [i_item_sk#166 AS item_sk#174, s_store_name#133 AS store_name#175, s_zip#134 AS store_zip#176, d_year#131 AS syear#177, count(1)#85 AS cnt#178, MakeDecimal(sum(UnscaledValue(ss_wholesale_cost#114))#86,17,2) AS s1#179, MakeDecimal(sum(UnscaledValue(ss_list_price#115))#87,17,2) AS s2#180, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#116))#88,17,2) AS s3#181] +Results [8]: [i_item_sk#168 AS item_sk#176, s_store_name#135 AS store_name#177, s_zip#136 AS store_zip#178, d_year#133 AS syear#179, count(1)#85 AS cnt#180, MakeDecimal(sum(UnscaledValue(ss_wholesale_cost#114))#86,17,2) AS s1#181, MakeDecimal(sum(UnscaledValue(ss_list_price#115))#87,17,2) AS s2#182, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#116))#88,17,2) AS s3#183] (178) Exchange -Input [8]: [item_sk#174, store_name#175, store_zip#176, syear#177, cnt#178, s1#179, s2#180, s3#181] -Arguments: hashpartitioning(item_sk#174, store_name#175, store_zip#176, 5), ENSURE_REQUIREMENTS, [plan_id=18] +Input [8]: [item_sk#176, store_name#177, store_zip#178, syear#179, cnt#180, s1#181, s2#182, s3#183] +Arguments: hashpartitioning(item_sk#176, store_name#177, store_zip#178, 5), ENSURE_REQUIREMENTS, [plan_id=18] (179) Sort [codegen id : 52] -Input [8]: [item_sk#174, store_name#175, store_zip#176, syear#177, cnt#178, s1#179, s2#180, s3#181] -Arguments: [item_sk#174 ASC NULLS FIRST, store_name#175 ASC NULLS FIRST, store_zip#176 ASC NULLS FIRST], false, 0 +Input [8]: [item_sk#176, store_name#177, store_zip#178, syear#179, cnt#180, s1#181, s2#182, s3#183] +Arguments: [item_sk#176 ASC NULLS FIRST, store_name#177 ASC NULLS FIRST, store_zip#178 ASC NULLS FIRST], false, 0 (180) SortMergeJoin [codegen id : 53] Left keys [3]: [item_sk#90, store_name#91, store_zip#92] -Right keys [3]: [item_sk#174, store_name#175, store_zip#176] +Right keys [3]: [item_sk#176, store_name#177, store_zip#178] Join type: Inner -Join condition: (cnt#178 <= cnt#102) +Join condition: (cnt#180 <= cnt#102) (181) Project [codegen id : 53] -Output [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#179, s2#180, s3#181, syear#177, cnt#178] -Input [25]: [product_name#89, item_sk#90, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, item_sk#174, store_name#175, store_zip#176, syear#177, cnt#178, s1#179, s2#180, s3#181] +Output [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#181, s2#182, s3#183, syear#179, cnt#180] +Input [25]: [product_name#89, item_sk#90, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, item_sk#176, store_name#177, store_zip#178, syear#179, cnt#180, s1#181, s2#182, s3#183] (182) Exchange -Input [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, 
s1#103, s2#104, s3#105, s1#179, s2#180, s3#181, syear#177, cnt#178] -Arguments: rangepartitioning(product_name#89 ASC NULLS FIRST, store_name#91 ASC NULLS FIRST, cnt#178 ASC NULLS FIRST, 5), ENSURE_REQUIREMENTS, [plan_id=19] +Input [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#181, s2#182, s3#183, syear#179, cnt#180] +Arguments: rangepartitioning(product_name#89 ASC NULLS FIRST, store_name#91 ASC NULLS FIRST, cnt#180 ASC NULLS FIRST, 5), ENSURE_REQUIREMENTS, [plan_id=19] (183) Sort [codegen id : 54] -Input [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#179, s2#180, s3#181, syear#177, cnt#178] -Arguments: [product_name#89 ASC NULLS FIRST, store_name#91 ASC NULLS FIRST, cnt#178 ASC NULLS FIRST], true, 0 +Input [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#181, s2#182, s3#183, syear#179, cnt#180] +Arguments: [product_name#89 ASC NULLS FIRST, store_name#91 ASC NULLS FIRST, cnt#180 ASC NULLS FIRST], true, 0 ===== Subqueries ===== @@ -1054,21 +1054,21 @@ BroadcastExchange (191) (188) Scan parquet spark_catalog.default.date_dim -Output [2]: [d_date_sk#130, d_year#131] +Output [2]: [d_date_sk#132, d_year#133] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), EqualTo(d_year,2000), IsNotNull(d_date_sk)] ReadSchema: struct (189) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#130, d_year#131] +Input [2]: [d_date_sk#132, d_year#133] (190) Filter [codegen id : 1] -Input [2]: [d_date_sk#130, d_year#131] -Condition : ((isnotnull(d_year#131) AND (d_year#131 = 2000)) AND isnotnull(d_date_sk#130)) +Input [2]: [d_date_sk#132, d_year#133] +Condition : ((isnotnull(d_year#133) AND (d_year#133 = 2000)) AND isnotnull(d_date_sk#132)) (191) BroadcastExchange -Input [2]: [d_date_sk#130, d_year#131] +Input [2]: [d_date_sk#132, d_year#133] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [plan_id=21] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt index 1440326b862e9..fafd7fd75cbd7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt @@ -648,7 +648,7 @@ BroadcastExchange (114) Output [2]: [d_date_sk#36, d_week_seq#100] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#101), IsNotNull(d_date_sk)] ReadSchema: struct (111) ColumnarToRow [codegen id : 1] @@ -741,7 +741,7 @@ BroadcastExchange (128) Output [2]: [d_date_sk#60, d_week_seq#108] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), 
EqualTo(d_week_seq,ScalarSubquery#109), IsNotNull(d_date_sk)] ReadSchema: struct (125) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt index 1e4ca929b9690..4d69899b3b17a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt @@ -618,7 +618,7 @@ BroadcastExchange (108) Output [2]: [d_date_sk#40, d_week_seq#100] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#101), IsNotNull(d_date_sk)] ReadSchema: struct (105) ColumnarToRow [codegen id : 1] @@ -711,7 +711,7 @@ BroadcastExchange (122) Output [2]: [d_date_sk#64, d_week_seq#108] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#109), IsNotNull(d_date_sk)] ReadSchema: struct (119) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt index 55bed0dade77f..afdfc51a17dd4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt @@ -272,7 +272,7 @@ BroadcastExchange (50) Output [2]: [d_date_sk#16, d_month_seq#26] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), EqualTo(d_month_seq,ScalarSubquery#27), IsNotNull(d_date_sk)] ReadSchema: struct (47) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt index 6713acc975445..a2638dac56456 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt @@ -242,7 +242,7 @@ BroadcastExchange (44) Output [2]: [d_date_sk#9, d_month_seq#26] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), EqualTo(d_month_seq,ScalarSubquery#27), IsNotNull(d_date_sk)] ReadSchema: struct (41) ColumnarToRow [codegen id : 1] From 6073d721933031a7c44086d1588ee7aa7b9be926 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 4 Aug 2023 19:54:49 +0800 Subject: [PATCH 27/68] [SPARK-44675][INFRA] Increase ReservedCodeCacheSize for release build ### What changes were proposed in this pull request? This PR increases `ReservedCodeCacheSize` to 1g for release build. The current warning and cache size: ``` OpenJDK 64-Bit Server VM warning: CodeCache is full. Compiler has been disabled. 
OpenJDK 64-Bit Server VM warning: Try increasing the code cache size using -XX:ReservedCodeCacheSize= ``` ``` $ ps -ef UID PID PPID C STIME TTY TIME CMD spark-rm 1 0 0 07:47 pts/0 00:00:00 bash /opt/spark-rm/do-release.sh spark-rm 13 1 0 07:47 ? 00:00:02 gpg-agent --homedir /home/spark-rm/.gnupg --use-standard-socket --daemon spark-rm 15 1 0 07:47 pts/0 00:00:00 bash /opt/spark-rm/release-build.sh package spark-rm 6491 0 0 09:56 pts/1 00:00:00 /bin/sh spark-rm 7809 15 0 10:07 pts/0 00:00:00 bash ./dev/make-distribution.sh --name hadoop3 --mvn /opt/spark-rm/output/spark-3.3.3-bin-hadoop3/ spark-rm 7977 7809 99 10:07 pts/0 00:01:16 /usr/bin/java -Xss128m -Xmx12g -classpath /opt/spark-rm/output/spark-3.3.3-bin-hadoop3/build/apach spark-rm 8205 6491 0 10:08 pts/1 00:00:00 ps -ef $ jinfo -flag ReservedCodeCacheSize 7977 -XX:ReservedCodeCacheSize=251658240 ``` ### Why are the changes needed? Reduce build time. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual test. Closes #42344 from wangyum/SPARK-44675. Authored-by: Yuming Wang Signed-off-by: Yuming Wang --- dev/create-release/release-build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index e0588ae934cd2..59e3b69b349b8 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -220,7 +220,7 @@ git clean -d -f -x rm -f .gitignore cd .. -export MAVEN_OPTS="-Xss128m -Xmx12g" +export MAVEN_OPTS="-Xss128m -Xmx12g -XX:ReservedCodeCacheSize=1g" if [[ "$1" == "package" ]]; then # Source and binary tarballs From 84ea6f242e4982187edc0a8f5786e7dc69ec31d7 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 4 Aug 2023 16:06:57 +0200 Subject: [PATCH 28/68] [SPARK-44656][CONNECT] Make all iterators CloseableIterators ### What changes were proposed in this pull request? This makes sure that all iterators used in Spark Connect scala client are `CloseableIterator`. 1. Makes `CustomSparkConnectBlockingStub.executePlan` return `CloseableIterator` and make all wrappers respect that. 2. Makes `ExecutePlanResponseReattachableIterator` a `CloseableIterator`, with an implementation that will inform the server that query result can be released with ReleaseExecute. 3. Makes `SparkResult.iterator` explicitly a `CloseableIterator`, and also register the `SparkResult.responses` iterator as with the `SparkResultCloseable` cleaner, which will make it close upon GC, if not closed explicitly sooner. 4. Because `Dataset.toLocalIterator` requires a Java iterator, implement a conversion to `java.util.Iterator with AutoCloseable` to be returned there 5. Using `CloseableIterator` consistently everywhere else removes the need to convert between iterator types. ### Why are the changes needed? Properly closeable iterators are needed for resource management, and with reattachable execution to inform server that processing finished. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Exercise current E2E tests. Co-authored-by: Alice Sayutina Closes #42331 from juliuszsompolski/closeable_iterators. 
Lead-authored-by: Juliusz Sompolski Co-authored-by: Alice Sayutina Signed-off-by: Herman van Hovell --- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../org/apache/spark/sql/SparkSession.scala | 21 ++++----- .../connect/client/CloseableIterator.scala | 46 +++++++++++++++++++ .../CustomSparkConnectBlockingStub.scala | 10 ++-- ...cutePlanResponseReattachableIterator.scala | 34 ++++++++------ .../client/GrpcExceptionConverter.scala | 10 +++- .../sql/connect/client/GrpcRetryHandler.scala | 24 ++++++---- .../connect/client/SparkConnectClient.scala | 7 +-- .../sql/connect/client/SparkResult.scala | 37 ++++++++------- .../client/arrow/ArrowDeserializer.scala | 1 + .../client/arrow/ArrowEncoderUtils.scala | 2 - .../client/arrow/ArrowSerializer.scala | 1 + .../apache/spark/sql/ClientE2ETestSuite.scala | 2 +- .../client/arrow/ArrowEncoderSuite.scala | 1 + 14 files changed, 133 insertions(+), 67 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 0f7b376955c96..8a7dce3987a44 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2832,7 +2832,7 @@ class Dataset[T] private[sql] ( /** * Returns an iterator that contains all rows in this Dataset. * - * The returned iterator implements [[AutoCloseable]]. For memory management it is better to + * The returned iterator implements [[AutoCloseable]]. For resource management it is better to * close it once you are done. If you don't close it, it and the underlying data will be cleaned * up once the iterator is garbage collected. * @@ -2840,7 +2840,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def toLocalIterator(): java.util.Iterator[T] = { - collectResult().destructiveIterator + collectResult().destructiveIterator.asJava } /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 59f3f3526ab2f..355d7edadc788 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -252,10 +252,8 @@ class SparkSession private[sql] ( .setSql(sqlText) .addAllPosArgs(args.map(toLiteralProto).toIterable.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseSeq = client.execute(plan.build()).asScala.toSeq - - // sequence is a lazy stream, force materialize it to make sure it is consumed. - responseSeq.foreach(_ => ()) + // .toBuffer forces that the iterator is consumed and closed + val responseSeq = client.execute(plan.build()).toBuffer.toSeq val response = responseSeq .find(_.hasSqlCommandResult) @@ -311,10 +309,8 @@ class SparkSession private[sql] ( .setSql(sqlText) .putAllArgs(args.asScala.mapValues(toLiteralProto).toMap.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseSeq = client.execute(plan.build()).asScala.toSeq - - // sequence is a lazy stream, force materialize it to make sure it is consumed. 
- responseSeq.foreach(_ => ()) + // .toBuffer forces that the iterator is consumed and closed + val responseSeq = client.execute(plan.build()).toBuffer.toSeq val response = responseSeq .find(_.hasSqlCommandResult) @@ -548,15 +544,14 @@ class SparkSession private[sql] ( f(builder) builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement()) val plan = proto.Plan.newBuilder().setRoot(builder).build() - client.execute(plan).asScala.foreach(_ => ()) + // .toBuffer forces that the iterator is consumed and closed + client.execute(plan).toBuffer } private[sql] def execute(command: proto.Command): Seq[ExecutePlanResponse] = { val plan = proto.Plan.newBuilder().setCommand(command).build() - val seq = client.execute(plan).asScala.toSeq - // sequence is a lazy stream, force materialize it to make sure it is consumed. - seq.foreach(_ => ()) - seq + // .toBuffer forces that the iterator is consumed and closed + client.execute(plan).toBuffer.toSeq } private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala new file mode 100644 index 0000000000000..891e50ed6e7bd --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client + +private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self => + def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable { + override def next() = self.next() + + override def hasNext() = self.hasNext + + override def close() = self.close() + } +} + +private[sql] object CloseableIterator { + + /** + * Wrap iterator to get CloseeableIterator, if it wasn't closeable already. 
+ */ + def apply[T](iterator: Iterator[T]): CloseableIterator[T] = iterator match { + case closeable: CloseableIterator[T] => closeable + case _ => + new CloseableIterator[T] { + override def next(): T = iterator.next() + + override def hasNext(): Boolean = iterator.hasNext + + override def close() = { /* empty */ } + } + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index bb20901eade17..73ff01e223f29 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.connect.client +import scala.collection.JavaConverters._ + import io.grpc.ManagedChannel import org.apache.spark.connect.proto._ @@ -27,15 +29,17 @@ private[client] class CustomSparkConnectBlockingStub( private val stub = SparkConnectServiceGrpc.newBlockingStub(channel) private val retryHandler = new GrpcRetryHandler(retryPolicy) - def executePlan(request: ExecutePlanRequest): java.util.Iterator[ExecutePlanResponse] = { + def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = { GrpcExceptionConverter.convert { GrpcExceptionConverter.convertIterator[ExecutePlanResponse]( - retryHandler.RetryIterator(request, stub.executePlan)) + retryHandler.RetryIterator[ExecutePlanRequest, ExecutePlanResponse]( + request, + r => CloseableIterator(stub.executePlan(r).asScala))) } } def executePlanReattachable( - request: ExecutePlanRequest): java.util.Iterator[ExecutePlanResponse] = { + request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = { GrpcExceptionConverter.convert { GrpcExceptionConverter.convertIterator[ExecutePlanResponse]( // Don't use retryHandler - own retry handling is inside. diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 41648c3c10048..d412d9b577064 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -50,7 +50,7 @@ class ExecutePlanResponseReattachableIterator( request: proto.ExecutePlanRequest, channel: ManagedChannel, retryPolicy: GrpcRetryHandler.RetryPolicy) - extends java.util.Iterator[proto.ExecutePlanResponse] + extends CloseableIterator[proto.ExecutePlanResponse] with Logging { val operationId = if (request.hasOperationId) { @@ -90,8 +90,8 @@ class ExecutePlanResponseReattachableIterator( // Initial iterator comes from ExecutePlan request. 
// Note: This is not retried, because no error would ever be thrown here, and GRPC will only - // throw error on first iterator.hasNext() or iterator.next() - private var iterator: java.util.Iterator[proto.ExecutePlanResponse] = + // throw error on first iter.hasNext() or iter.next() + private var iter: java.util.Iterator[proto.ExecutePlanResponse] = rawBlockingStub.executePlan(initialRequest) override def next(): proto.ExecutePlanResponse = synchronized { @@ -105,11 +105,11 @@ class ExecutePlanResponseReattachableIterator( var firstTry = true val ret = retry { if (firstTry) { - // on first try, we use the existing iterator. + // on first try, we use the existing iter. firstTry = false } else { - // on retry, the iterator is borked, so we need a new one - iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + // on retry, the iter is borked, so we need a new one + iter = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) } callIter(_.next()) } @@ -138,23 +138,23 @@ class ExecutePlanResponseReattachableIterator( try { retry { if (firstTry) { - // on first try, we use the existing iterator. + // on first try, we use the existing iter. firstTry = false } else { - // on retry, the iterator is borked, so we need a new one - iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + // on retry, the iter is borked, so we need a new one + iter = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) } var hasNext = callIter(_.hasNext()) // Graceful reattach: - // If iterator ended, but there was no ResultComplete, it means that there is more, + // If iter ended, but there was no ResultComplete, it means that there is more, // and we need to reattach. if (!hasNext && !resultComplete) { do { - iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + iter = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) assert(!resultComplete) // shouldn't change... hasNext = callIter(_.hasNext()) - // It's possible that the new iterator will be empty, so we need to loop to get another. - // Eventually, there will be a non empty iterator, because there is always a + // It's possible that the new iter will be empty, so we need to loop to get another. + // Eventually, there will be a non empty iter, because there is always a // ResultComplete inserted by the server at the end of the stream. } while (!hasNext) } @@ -167,6 +167,10 @@ class ExecutePlanResponseReattachableIterator( } } + override def close(): Unit = { + releaseAll() + } + /** * Inform the server to release the buffered execution results until and including given result. * @@ -204,7 +208,7 @@ class ExecutePlanResponseReattachableIterator( */ private def callIter[V](iterFun: java.util.Iterator[proto.ExecutePlanResponse] => V) = { try { - iterFun(iterator) + iterFun(iter) } catch { case ex: StatusRuntimeException if StatusProto @@ -217,7 +221,7 @@ class ExecutePlanResponseReattachableIterator( ex) } // Try a new ExecutePlan, and throw upstream for retry. 
- iterator = rawBlockingStub.executePlan(initialRequest) + iter = rawBlockingStub.executePlan(initialRequest) throw new GrpcRetryHandler.RetryException } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 1a42ec821d84f..7ff3421a5a045 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -31,8 +31,8 @@ private[client] object GrpcExceptionConverter { } } - def convertIterator[T](iter: java.util.Iterator[T]): java.util.Iterator[T] = { - new java.util.Iterator[T] { + def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = { + new CloseableIterator[T] { override def hasNext: Boolean = { convert { iter.hasNext @@ -44,6 +44,12 @@ private[client] object GrpcExceptionConverter { iter.next() } } + + override def close(): Unit = { + convert { + iter.close() + } + } } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala index 47ff975b26756..6dad5b4b3a9b4 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala @@ -45,13 +45,13 @@ private[client] class GrpcRetryHandler(private val retryPolicy: GrpcRetryHandler * @tparam U * The type of the response. */ - class RetryIterator[T, U](request: T, call: T => java.util.Iterator[U]) - extends java.util.Iterator[U] { + class RetryIterator[T, U](request: T, call: T => CloseableIterator[U]) + extends CloseableIterator[U] { private var opened = false // we only retry if it fails on first call when using the iterator - private var iterator = call(request) + private var iter = call(request) - private def retryIter[V](f: java.util.Iterator[U] => V) = { + private def retryIter[V](f: Iterator[U] => V) = { if (!opened) { opened = true var firstTry = true @@ -61,26 +61,30 @@ private[client] class GrpcRetryHandler(private val retryPolicy: GrpcRetryHandler firstTry = false } else { // on retry, we need to call the RPC again. 
- iterator = call(request) + iter = call(request) } - f(iterator) + f(iter) } } else { - f(iterator) + f(iter) } } override def next: U = { - retryIter(_.next()) + retryIter(_.next) } override def hasNext: Boolean = { - retryIter(_.hasNext()) + retryIter(_.hasNext) + } + + override def close(): Unit = { + iter.close() } } object RetryIterator { - def apply[T, U](request: T, call: T => java.util.Iterator[U]): RetryIterator[T, U] = + def apply[T, U](request: T, call: T => CloseableIterator[U]): RetryIterator[T, U] = new RetryIterator(request, call) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 3d20be88888c3..a028df536cf88 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -75,10 +75,11 @@ private[sql] class SparkConnectClient( /** * Execute the plan and return response iterator. * - * It returns an open iterator. The caller needs to ensure that this iterator is fully consumed, - * otherwise resources held by a re-attachable query may be left dangling until server timeout. + * It returns CloseableIterator. For resource management it is better to close it once you are + * done. If you don't close it, it and the underlying data will be cleaned up once the iterator + * is garbage collected. */ - def execute(plan: proto.Plan): java.util.Iterator[proto.ExecutePlanResponse] = { + def execute(plan: proto.Plan): CloseableIterator[proto.ExecutePlanResponse] = { artifactManager.uploadAllClassFileArtifacts() val request = proto.ExecutePlanRequest .newBuilder() diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 93c32aa2954a3..609e84779fbfc 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -27,14 +27,14 @@ import org.apache.arrow.vector.types.pojo import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder} -import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator} +import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator} import org.apache.spark.sql.connect.client.util.Cleanable import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ArrowUtils private[sql] class SparkResult[T]( - responses: java.util.Iterator[proto.ExecutePlanResponse], + responses: CloseableIterator[proto.ExecutePlanResponse], allocator: BufferAllocator, encoder: AgnosticEncoder[T], timeZoneId: String) @@ -198,22 +198,22 @@ private[sql] class SparkResult[T]( /** * Returns an iterator over the contents of the result. 
*/ - def iterator: java.util.Iterator[T] with AutoCloseable = + def iterator: CloseableIterator[T] = buildIterator(destructive = false) /** * Returns an destructive iterator over the contents of the result. */ - def destructiveIterator: java.util.Iterator[T] with AutoCloseable = + def destructiveIterator: CloseableIterator[T] = buildIterator(destructive = true) - private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = { - new java.util.Iterator[T] with AutoCloseable { - private[this] var iterator: CloseableIterator[T] = _ + private def buildIterator(destructive: Boolean): CloseableIterator[T] = { + new CloseableIterator[T] { + private[this] var iter: CloseableIterator[T] = _ private def initialize(): Unit = { - if (iterator == null) { - iterator = new ArrowDeserializingIterator( + if (iter == null) { + iter = new ArrowDeserializingIterator( createEncoder(encoder, schema), new ConcatenatingArrowStreamReader( allocator, @@ -225,17 +225,17 @@ private[sql] class SparkResult[T]( override def hasNext: Boolean = { initialize() - iterator.hasNext + iter.hasNext } override def next(): T = { initialize() - iterator.next() + iter.next() } override def close(): Unit = { - if (iterator != null) { - iterator.close() + if (iter != null) { + iter.close() } } } @@ -246,7 +246,7 @@ private[sql] class SparkResult[T]( */ override def close(): Unit = cleaner.close() - override val cleaner: AutoCloseable = new SparkResultCloseable(resultMap) + override val cleaner: AutoCloseable = new SparkResultCloseable(resultMap, responses) private class ResultMessageIterator(destructive: Boolean) extends AbstractMessageIterator { private[this] var totalBytesRead = 0L @@ -296,7 +296,12 @@ private[sql] class SparkResult[T]( } } -private[client] class SparkResultCloseable(resultMap: mutable.Map[Int, (Long, Seq[ArrowMessage])]) +private[client] class SparkResultCloseable( + resultMap: mutable.Map[Int, (Long, Seq[ArrowMessage])], + responses: CloseableIterator[proto.ExecutePlanResponse]) extends AutoCloseable { - override def close(): Unit = resultMap.values.foreach(_._2.foreach(_.close())) + override def close(): Unit = { + resultMap.values.foreach(_._2.foreach(_.close())) + responses.close() + } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 509ceffc55282..55dd640f1b6b1 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors} import org.apache.spark.sql.types.Decimal diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala index ed27336985416..b9badc5c936fa 100644 --- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala @@ -40,8 +40,6 @@ private[arrow] object ArrowEncoderUtils { } } -trait CloseableIterator[E] extends Iterator[E] with AutoCloseable - private[arrow] object StructVectors { def unapply(v: AnyRef): Option[(StructVector, Seq[FieldVector])] = v match { case root: VectorSchemaRoot => Option((null, root.getFieldVectors.asScala.toSeq)) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index c4a2cfa8a850f..9e67522711c6e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} +import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.util.ArrowUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 1403d460b516f..98fbff84ba674 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -755,7 +755,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM private def checkSameResult[E](expected: scala.collection.Seq[E], dataset: Dataset[E]): Unit = { dataset.withResult { result => - assert(expected === result.iterator.asScala.toBuffer) + assert(expected === result.iterator.toBuffer) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index dd0e9347ac88b..7a8e8465a70cc 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._ +import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType} From 4a344b6a37a85fafb624985ce5282e50fb971866 
Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 4 Aug 2023 07:47:53 -0700 Subject: [PATCH 29/68] [SPARK-44674][GRAPHX] Remove `BytecodeUtils` from `graphx` module ### What changes were proposed in this pull request? `BytecodeUtils` and `BytecodeUtilsSuite` introduced in [Added the BytecodeUtils class for analyzing bytecode](https://github.com/apache/spark/commit/ae12d163dc2462ededefc8d31900803cf9a782a5). https://github.com/apache/spark/pull/23098 deleted the `BytecodeUtilsSuite`, and after https://github.com/apache/spark/pull/35566, `BytecodeUtils` is no longer used. So this pr remove `BytecodeUtils` from `graphx` module. ### Why are the changes needed? Clean up unnecessary code. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions Closes #42343 from LuciferYang/SPARK-44674. Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../spark/graphx/util/BytecodeUtils.scala | 134 ------------------ 1 file changed, 134 deletions(-) delete mode 100644 graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala deleted file mode 100644 index 3b08b9d62cfce..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.util - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import scala.collection.mutable.HashSet - -import org.apache.xbean.asm9.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm9.Opcodes._ - -import org.apache.spark.util.Utils - -/** - * Includes an utility function to test whether a function accesses a specific attribute - * of an object. - */ -private[graphx] object BytecodeUtils { - - /** - * Test whether the given closure invokes the specified method in the specified class. 
- */ - def invokedMethod(closure: AnyRef, targetClass: Class[_], targetMethod: String): Boolean = { - if (_invokedMethod(closure.getClass, "apply", targetClass, targetMethod)) { - true - } else { - // look at closures enclosed in this closure - for (f <- closure.getClass.getDeclaredFields - if f.getType.getName.startsWith("scala.Function")) { - f.setAccessible(true) - if (invokedMethod(f.get(closure), targetClass, targetMethod)) { - return true - } - } - false - } - } - - private def _invokedMethod(cls: Class[_], method: String, - targetClass: Class[_], targetMethod: String): Boolean = { - - val seen = new HashSet[(Class[_], String)] - var stack = List[(Class[_], String)]((cls, method)) - - while (stack.nonEmpty) { - val c = stack.head._1 - val m = stack.head._2 - stack = stack.tail - seen.add((c, m)) - val finder = new MethodInvocationFinder(c.getName, m) - getClassReader(c).accept(finder, 0) - for (classMethod <- finder.methodsInvoked) { - if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { - return true - } else if (!seen.contains(classMethod)) { - stack = classMethod :: stack - } - } - } - false - } - - /** - * Get an ASM class reader for a given class from the JAR that loaded it. - */ - private def getClassReader(cls: Class[_]): ClassReader = { - // Copy data over, before delegating to ClassReader - else we can run out of open file handles. - val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" - val resourceStream = cls.getResourceAsStream(className) - // todo: Fixme - continuing with earlier behavior ... - if (resourceStream == null) return new ClassReader(resourceStream) - - val baos = new ByteArrayOutputStream(128) - Utils.copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) - } - - /** - * Given the class name, return whether we should look into the class or not. This is used to - * skip examining a large quantity of Java or Scala classes that we know for sure wouldn't access - * the closures. Note that the class name is expected in ASM style (i.e. use "/" instead of "."). - */ - private def skipClass(className: String): Boolean = { - val c = className - c.startsWith("java/") || c.startsWith("scala/") || c.startsWith("javax/") - } - - /** - * Find the set of methods invoked by the specified method in the specified class. - * For example, after running the visitor, - * MethodInvocationFinder("spark/graph/Foo", "test") - * its methodsInvoked variable will contain the set of methods invoked directly by - * Foo.test(). Interface invocations are not returned as part of the result set because we cannot - * determine the actual method invoked by inspecting the bytecode. 
- */ - private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM9) { - - val methodsInvoked = new HashSet[(Class[_], String)] - - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - if (name == methodName) { - new MethodVisitor(ASM9) { - override def visitMethodInsn( - op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = { - if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { - if (!skipClass(owner)) { - methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) - } - } - } - } - } else { - null - } - } - } -} From 2d3bb4d5db71cc14e617dec8fa69799552b75975 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 4 Aug 2023 08:10:08 -0700 Subject: [PATCH 30/68] [SPARK-44672][INFRA] Fix git ignore rules related to Antlr ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/41928 moved antlr4 related files and directories from the `sql/catalyst` module to the `sql/api` module. This pr fix the corresponding git ignore rules to avoid unexpected diffs when executing `git status`. ### Why are the changes needed? Avoid unexpected diffs when executing `git status`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Manual check `git status` **Before** ``` sql/api/gen/ sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/ ``` **After** no diff. Closes #42342 from LuciferYang/minor-gitignore. Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .gitignore | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 11141961bf805..064b502175b79 100644 --- a/.gitignore +++ b/.gitignore @@ -117,6 +117,6 @@ spark-warehouse/ node_modules # For Antlr -sql/catalyst/gen/ -sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens -sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/ +sql/api/gen/ +sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens +sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/ From 780bae928399947a351dd4b36afcfc7a8be06b13 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sat, 5 Aug 2023 00:58:16 +0900 Subject: [PATCH 31/68] [SPARK-44671][PYTHON][CONNECT] Retry ExecutePlan in case initial request didn't reach server in Python client ### What changes were proposed in this pull request? The fix for the symmetry to https://github.com/apache/spark/pull/42282. ### Why are the changes needed? See also https://github.com/apache/spark/pull/42282 ### Does this PR introduce _any_ user-facing change? See also https://github.com/apache/spark/pull/42282 ### How was this patch tested? See also https://github.com/apache/spark/pull/42282 Closes #42338 from HyukjinKwon/SPARK-44671. 
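In short, the recovery rule this change introduces can be sketched as follows. This is an illustrative summary only, not the client's real API: `handle_operation_not_found`, `error_message` and `restart_execute_plan` are hypothetical stand-ins for logic that lives in `ExecutePlanResponseReattachableIterator._call_iter` in the actual diff below. If the server reports `INVALID_HANDLE.OPERATION_NOT_FOUND` before any response has been returned, the initial `ExecutePlan` never reached the server, so a fresh `ExecutePlan` can be started and the surrounding retry block is signalled via `RetryException`; if responses were already received, restarting would be unsafe and the error is surfaced instead.

```python
class RetryException(Exception):
    """Raised inside the retry block to force a retry regardless of the retry policy."""


def handle_operation_not_found(error_message, last_returned_response_id, restart_execute_plan):
    # Hypothetical helper mirroring the decision made when next() on the response
    # iterator fails because the operation id is unknown on the server.
    if "INVALID_HANDLE.OPERATION_NOT_FOUND" not in error_message:
        raise RuntimeError(error_message)
    if last_returned_response_id is not None:
        # The operation did run and already produced responses; restarting could duplicate work.
        raise RuntimeError(
            "OPERATION_NOT_FOUND on the server but responses were already received from it."
        )
    # The initial ExecutePlan never reached the server: start a new one and let the
    # surrounding retry logic pick up the fresh iterator.
    restart_execute_plan()
    raise RetryException()
```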
Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/core.py | 7 ++- python/pyspark/sql/connect/client/reattach.py | 51 ++++++++++++++++--- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index a82c596555f8a..a7c3a92d3b1dc 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -65,7 +65,10 @@ from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation from pyspark.sql.connect.client.artifact import ArtifactManager -from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator +from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, + RetryException, +) from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib @@ -1549,7 +1552,7 @@ def __exit__( ) -> Optional[bool]: if isinstance(exc_val, BaseException): # Swallow the exception. - if self._can_retry(exc_val): + if self._can_retry(exc_val) or isinstance(exc_val, RetryException): self._retry_state.set_exception(exc_val) return True # Bubble up the exception. diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 702107d97f549..70c7d126ff105 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -21,10 +21,13 @@ import warnings import uuid from collections.abc import Generator -from typing import Optional, Dict, Any, Iterator, Iterable, Tuple +from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast from multiprocessing.pool import ThreadPool import os +import grpc +from grpc_status import rpc_status + import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib @@ -42,15 +45,12 @@ class ExecutePlanResponseReattachableIterator(Generator): Initial iterator is the result of an ExecutePlan on the request, but it can be reattached with ReattachExecute request. ReattachExecute request is provided the responseId of last returned ExecutePlanResponse on the iterator to return a new iterator from server that continues after - that. + that. If the initial ExecutePlan did not even reach the server, and hence reattach fails with + INVALID_HANDLE.OPERATION_NOT_FOUND, we attempt to retry ExecutePlan. In reattachable execute the server does buffer some responses in case the client needs to backtrack. To let server release this buffer sooner, this iterator asynchronously sends ReleaseExecute RPCs that instruct the server to release responses that it already processed. - - Note: If the initial ExecutePlan did not even reach the server and execution didn't start, - the ReattachExecute can still fail with INVALID_HANDLE.OPERATION_NOT_FOUND, failing the whole - operation. """ _release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8) @@ -93,6 +93,7 @@ def __init__( # Initial iterator comes from ExecutePlan request. # Note: This is not retried, because no error would ever be thrown here, and GRPC will only # throw error on first self._has_next(). 
+ self._metadata = metadata self._iterator: Iterator[pb2.ExecutePlanResponse] = iter( self._stub.ExecutePlan(self._initial_request, metadata=metadata) ) @@ -139,7 +140,7 @@ def _has_next(self) -> bool: if self._current is None: try: - self._current = next(self._iterator) + self._current = self._call_iter(lambda: next(self._iterator)) except StopIteration: pass @@ -159,7 +160,7 @@ def _has_next(self) -> bool: # shouldn't change assert not self._result_complete try: - self._current = next(self._iterator) + self._current = self._call_iter(lambda: next(self._iterator)) except StopIteration: pass has_next = self._current is not None @@ -226,6 +227,33 @@ def target() -> None: ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) self._result_complete = True + def _call_iter(self, iter_fun: Callable) -> Any: + """ + Call next() on the iterator. If this fails with this operationId not existing + on the server, this means that the initial ExecutePlan request didn't even reach the + server. In that case, attempt to start again with ExecutePlan. + + Called inside retry block, so retryable failure will get handled upstream. + """ + try: + return iter_fun() + except grpc.RpcError as e: + status = rpc_status.from_call(cast(grpc.Call, e)) + if "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message: + if self._last_returned_response_id is not None: + raise RuntimeError( + "OPERATION_NOT_FOUND on the server but " + "responses were already received from it.", + e, + ) + # Try a new ExecutePlan, and throw upstream for retry. + self._iterator = iter( + self._stub.ExecutePlan(self._initial_request, metadata=self._metadata) + ) + raise RetryException() + else: + raise e + def _create_reattach_execute_request(self) -> pb2.ReattachExecuteRequest: reattach = pb2.ReattachExecuteRequest( session_id=self._initial_request.session_id, @@ -269,3 +297,10 @@ def close(self) -> None: def __del__(self) -> None: return self.close() + + +class RetryException(Exception): + """ + An exception that can be thrown upstream when inside retry and which will be retryable + regardless of policy. + """ From 62415dc59627e1f7b4e3449ae728e93c1fc0b74f Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Fri, 4 Aug 2023 13:15:52 -0700 Subject: [PATCH 32/68] [SPARK-44644][PYTHON] Improve error messages for Python UDTFs with pickling errors ### What changes were proposed in this pull request? This PR improves the error messages when a Python UDTF failed to pickle. ### Why are the changes needed? To make the error message more user-friendly ### Does this PR introduce _any_ user-facing change? Yes, before this PR, when a UDTF fails to pickle, it throws this confusing exception: ``` _pickle.PicklingError: Cannot pickle files that are not opened for reading: w ``` After this PR, the error is more clear: `[UDTF_SERIALIZATION_ERROR] Cannot serialize the UDTF 'TestUDTF': Please check the stack trace and make sure that the function is serializable.` And for spark session access inside a UDTF: `[UDTF_SERIALIZATION_ERROR] it appears that you are attempting to reference SparkSession inside a UDTF. SparkSession can only be used on the driver, not in code that runs on workers. Please remove the reference and try again.` ### How was this patch tested? New UTs. Closes #42309 from allisonwang-db/spark-44644-pickling. 
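As an illustration of the new behavior (a hedged sketch mirroring the tests added below; exact error message text aside):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udtf

spark = SparkSession.builder.getOrCreate()
df = spark.range(10)

@udtf(returnType="x: int")
class TestUDTF:
    def eval(self):
        df.collect()  # captures a driver-side DataFrame (and thus the SparkSession)
        yield 1,

# With this change, TestUDTF().collect() raises PySparkRuntimeError with the
# error class UDTF_SERIALIZATION_ERROR instead of an opaque PicklingError.
```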
Authored-by: allisonwang-db Signed-off-by: Takuya UESHIN --- .../pyspark/cloudpickle/cloudpickle_fast.py | 2 +- python/pyspark/errors/error_classes.py | 5 ++++ python/pyspark/sql/connect/plan.py | 15 +++++++++-- python/pyspark/sql/tests/test_udtf.py | 27 +++++++++++++++++++ python/pyspark/sql/udtf.py | 25 ++++++++++++++++- 5 files changed, 70 insertions(+), 4 deletions(-) diff --git a/python/pyspark/cloudpickle/cloudpickle_fast.py b/python/pyspark/cloudpickle/cloudpickle_fast.py index 63aaffa096b2c..ee1f4b8ee967e 100644 --- a/python/pyspark/cloudpickle/cloudpickle_fast.py +++ b/python/pyspark/cloudpickle/cloudpickle_fast.py @@ -631,7 +631,7 @@ def dump(self, obj): try: return Pickler.dump(self, obj) except RuntimeError as e: - if "recursion" in e.args[0]: + if len(e.args) > 0 and "recursion" in e.args[0]: msg = ( "Could not pickle object as excessively deep recursion " "required." diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 84448f1507dd8..a534bc6deb41e 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -753,6 +753,11 @@ "Mismatch in return type for the UDTF ''. Expected a 'StructType', but got ''. Please ensure the return type is a correctly formatted StructType." ] }, + "UDTF_SERIALIZATION_ERROR" : { + "message" : [ + "Cannot serialize the UDTF '': " + ] + }, "UNEXPECTED_RESPONSE_FROM_SERVER" : { "message" : [ "Unexpected response from iterator server." diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 58dffd93bf9b5..7da93ef413c20 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -21,6 +21,7 @@ from typing import Any, List, Optional, Type, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict import functools import json +import pickle from threading import Lock from inspect import signature, isclass @@ -40,7 +41,7 @@ LiteralExpression, ) from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType -from pyspark.errors import PySparkTypeError, PySparkNotImplementedError +from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName @@ -2202,7 +2203,17 @@ def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF: if self._return_type is not None: udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type)) udtf.eval_type = self._eval_type - udtf.command = CloudPickleSerializer().dumps(self._func) + try: + udtf.command = CloudPickleSerializer().dumps(self._func) + except pickle.PicklingError: + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "Please check the stack trace and " + "make sure the function is serializable.", + }, + ) udtf.python_ver = self._python_ver return udtf diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 26da83980e160..4d36b53799503 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -26,6 +26,7 @@ PythonException, PySparkTypeError, AnalysisException, + PySparkRuntimeError, ) from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType @@ -740,6 +741,32 @@ def upper(s: str): }, ) + def test_udtf_pickle_error(self): + with tempfile.TemporaryDirectory() as d: + file = os.path.join(d, "file.txt") + file_obj = open(file, "w") + + @udtf(returnType="x: int") + 
class TestUDTF: + def eval(self): + file_obj + yield 1, + + with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"): + TestUDTF().collect() + + def test_udtf_access_spark_session(self): + df = self.spark.range(10) + + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + df.collect() + yield 1, + + with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"): + TestUDTF().collect() + def test_udtf_no_eval(self): with self.assertRaises(PySparkAttributeError) as e: diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 74a9084c6cd55..fea0f74c8f2f6 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -17,6 +17,7 @@ """ User-defined table function related classes and functions """ +import pickle from dataclasses import dataclass from functools import wraps import inspect @@ -303,7 +304,29 @@ def _create_judtf(self, func: Type) -> JavaObject: spark = SparkSession._getActiveSessionOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, func) + try: + wrapped_func = _wrap_function(sc, func) + except pickle.PicklingError as e: + if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e): + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "it appears that you are attempting to reference SparkSession " + "inside a UDTF. SparkSession can only be used on the driver, " + "not in code that runs on workers. Please remove the reference " + "and try again.", + }, + ) from None + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "Please check the stack trace and make sure the " + "function is serializable.", + }, + ) + assert sc._jvm is not None if self.returnType is None: judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction( From 8b53fed7ef0edaaf948ec67413017e60444230fd Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Fri, 4 Aug 2023 16:44:01 -0700 Subject: [PATCH 33/68] [SPARK-44663][PYTHON] Disable arrow optimization by default for Python UDTFs ### What changes were proposed in this pull request? This PR disables arrow optimization by default for Python UDTFs. ### Why are the changes needed? To make Python UDTFs consistent with Python UDFs (arrow optimization is by default disabled). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit tests Closes #42329 from allisonwang-db/spark-44663-disable-arrow. 
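For reference, a small sketch (assuming Spark 3.5 and the conf name used in this patch) of how the new default can still be overridden per session:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udtf
from pyspark.rdd import PythonEvalType

spark = SparkSession.builder.getOrCreate()

class TestUDTF:
    def eval(self):
        yield 1,

# New default: regular (non-Arrow) Python UDTF evaluation.
assert udtf(TestUDTF, returnType="x: int").evalType == PythonEvalType.SQL_TABLE_UDF

# Opting back in to Arrow-optimized UDTFs:
spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
assert udtf(TestUDTF, returnType="x: int").evalType == PythonEvalType.SQL_ARROW_TABLE_UDF
```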
Authored-by: allisonwang-db Signed-off-by: Takuya UESHIN --- python/pyspark/sql/connect/udtf.py | 10 +++++----- python/pyspark/sql/tests/test_udtf.py | 16 ++++++++++++++++ python/pyspark/sql/udtf.py | 10 +++++----- .../org/apache/spark/sql/internal/SQLConf.scala | 2 +- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 919994401c802..5a95075a65537 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -70,11 +70,11 @@ def _create_py_udtf( else: from pyspark.sql.connect.session import _active_spark_session - arrow_enabled = ( - _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true" - if _active_spark_session is not None - else True - ) + arrow_enabled = False + if _active_spark_session is not None: + value = _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") + if isinstance(value, str) and value.lower() == "true": + arrow_enabled = True # Create a regular Python UDTF and check for invalid handler class. regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 4d36b53799503..9caf267e48df3 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -1723,6 +1723,22 @@ def eval(self, x: str): PythonEvalType.SQL_ARROW_TABLE_UDF, ) + def test_udtf_arrow_sql_conf(self): + class TestUDTF: + def eval(self): + yield 1, + + # We do not use `self.sql_conf` here to test the SQL SET command + # instead of using PySpark's `spark.conf.set`. + old_value = self.spark.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") + self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=False") + self.assertEqual(udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_TABLE_UDF) + self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=True") + self.assertEqual( + udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_ARROW_TABLE_UDF + ) + self.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", old_value) + def test_udtf_eval_returning_non_tuple(self): class TestUDTF: def eval(self, a: int): diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index fea0f74c8f2f6..027a2646a4657 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -106,11 +106,11 @@ def _create_py_udtf( from pyspark.sql import SparkSession session = SparkSession._instantiatedSession - arrow_enabled = ( - session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true" - if session is not None - else True - ) + arrow_enabled = False + if session is not None: + value = session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") + if isinstance(value, str) and value.lower() == "true": + arrow_enabled = True # Create a regular Python UDTF and check for invalid handler class. 
regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ad2d323140a6d..bcf8ce2bc5407 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2942,7 +2942,7 @@ object SQLConf { .doc("Enable Arrow optimization for Python UDTFs.") .version("3.5.0") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val PYTHON_TABLE_UDF_ANALYZER_MEMORY = buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory") From 8e60a04d19ed7b1d340eb7fb068df365f7969b43 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 5 Aug 2023 08:05:29 +0800 Subject: [PATCH 34/68] [SPARK-44667][INFRA] Uninstall large ML libraries for non-ML jobs ### What changes were proposed in this pull request? Uninstall large ML libraries for non-ML jobs ### Why are the changes needed? ML is integrating external frameworks: torch, deepspeed (maybe xgboost in future) those libraries are huge, and not needed in other jobs. this PR uninstall torch, which save ~1.3G ![image](https://github.com/apache/spark/assets/7322292/e8181924-ca30-4e1e-8808-659f6a75c1d1) ### Does this PR introduce _any_ user-facing change? no, infra-only ### How was this patch tested? updated CI Closes #42334 from zhengruifeng/infra_uninstall_torch. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .github/workflows/build_and_test.yml | 14 +++++++++++--- dev/sparktestsupport/modules.py | 18 ++++++++++++++++-- dev/sparktestsupport/utils.py | 23 +++++++++++++---------- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index ea0c8e1d7fdeb..04585481a9ce6 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -350,9 +350,11 @@ jobs: - >- pyspark-errors - >- - pyspark-sql, pyspark-mllib, pyspark-resource, pyspark-testing + pyspark-sql, pyspark-resource, pyspark-testing - >- - pyspark-core, pyspark-streaming, pyspark-ml + pyspark-core, pyspark-streaming + - >- + pyspark-mllib, pyspark-ml, pyspark-ml-connect - >- pyspark-pandas - >- @@ -411,7 +413,13 @@ jobs: restore-keys: | pyspark-coursier- - name: Free up disk space - run: ./dev/free_disk_space_container + shell: 'script -q -e -c "bash {0}"' + run: | + if [[ "$MODULES_TO_TEST" != *"pyspark-ml"* ]]; then + # uninstall libraries dedicated for ML testing + python3.9 -m pip uninstall -y torch torchvision torcheval torchtnt tensorboard mlflow + fi + ./dev/free_disk_space_container - name: Install Java ${{ matrix.java }} uses: actions/setup-java@v3 with: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 9e45e0facefc1..b2f978c47ea30 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -817,10 +817,9 @@ def __hash__(self): pyspark_connect = Module( name="pyspark-connect", - dependencies=[pyspark_sql, pyspark_ml, connect], + dependencies=[pyspark_sql, connect], source_file_regexes=[ "python/pyspark/sql/connect", - "python/pyspark/ml/connect", ], python_test_goals=[ # sql doctests @@ -871,6 +870,21 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_pandas_udf_scalar", "pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg", "pyspark.sql.tests.connect.test_parity_pandas_udf_window", + 
], + excluded_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and + # they aren't available there + ], +) + + +pyspark_ml_connect = Module( + name="pyspark-ml-connect", + dependencies=[pyspark_connect, pyspark_ml], + source_file_regexes=[ + "python/pyspark/ml/connect", + ], + python_test_goals=[ # ml doctests "pyspark.ml.connect.functions", # ml unittests diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py index 816c982bd60e9..e79d864c32095 100755 --- a/dev/sparktestsupport/utils.py +++ b/dev/sparktestsupport/utils.py @@ -112,25 +112,28 @@ def determine_modules_to_test(changed_modules, deduplicated=True): >>> sorted([x.name for x in determine_modules_to_test([modules.sql])]) ... # doctest: +NORMALIZE_WHITESPACE ['avro', 'connect', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver', - 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', - 'pyspark-pandas-connect', 'pyspark-pandas-slow', 'pyspark-pandas-slow-connect', 'pyspark-sql', - 'pyspark-testing', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] + 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-ml-connect', 'pyspark-mllib', + 'pyspark-pandas', 'pyspark-pandas-connect', 'pyspark-pandas-slow', + 'pyspark-pandas-slow-connect', 'pyspark-sql', 'pyspark-testing', 'repl', 'sparkr', 'sql', + 'sql-kafka-0-10'] >>> sorted([x.name for x in determine_modules_to_test( ... [modules.sparkr, modules.sql], deduplicated=False)]) ... # doctest: +NORMALIZE_WHITESPACE ['avro', 'connect', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver', - 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', - 'pyspark-pandas-connect', 'pyspark-pandas-slow', 'pyspark-pandas-slow-connect', 'pyspark-sql', - 'pyspark-testing', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] + 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-ml-connect', 'pyspark-mllib', + 'pyspark-pandas', 'pyspark-pandas-connect', 'pyspark-pandas-slow', + 'pyspark-pandas-slow-connect', 'pyspark-sql', 'pyspark-testing', 'repl', 'sparkr', 'sql', + 'sql-kafka-0-10'] >>> sorted([x.name for x in determine_modules_to_test( ... [modules.sql, modules.core], deduplicated=False)]) ... 
# doctest: +NORMALIZE_WHITESPACE ['avro', 'catalyst', 'connect', 'core', 'docker-integration-tests', 'examples', 'graphx', 'hive', 'hive-thriftserver', 'mllib', 'mllib-local', 'protobuf', 'pyspark-connect', - 'pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-connect', - 'pyspark-pandas-slow', 'pyspark-pandas-slow-connect', 'pyspark-resource', 'pyspark-sql', - 'pyspark-streaming', 'pyspark-testing', 'repl', 'root', 'sparkr', 'sql', 'sql-kafka-0-10', - 'streaming', 'streaming-kafka-0-10', 'streaming-kinesis-asl'] + 'pyspark-core', 'pyspark-ml', 'pyspark-ml-connect', 'pyspark-mllib', 'pyspark-pandas', + 'pyspark-pandas-connect', 'pyspark-pandas-slow', 'pyspark-pandas-slow-connect', + 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming', 'pyspark-testing', 'repl', + 'root', 'sparkr', 'sql', 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10', + 'streaming-kinesis-asl'] """ modules_to_test = set() for module in changed_modules: From cf64008fce77b38d1237874b04f5ac124b01b3a8 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 4 Aug 2023 17:41:27 -0700 Subject: [PATCH 35/68] [SPARK-44433][PYTHON][CONNECT][SS][FOLLOWUP] Set back USE_DAEMON after creating streaming python processes ### What changes were proposed in this pull request? Followup of this comment: https://github.com/apache/spark/pull/42283#discussion_r1283804782 Change back the spark conf after creating streaming python process. ### Why are the changes needed? Bug fix ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Config only change Closes #42341 from WweiL/SPARK-44433-followup-USEDAEMON. Authored-by: Wei Liu Signed-off-by: Takuya UESHIN --- .../api/python/StreamingPythonRunner.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index f14289f984a2f..a079743c847ae 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -66,14 +66,19 @@ private[spark] class StreamingPythonRunner( envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) - conf.set(PYTHON_USE_DAEMON, false) envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl) - val (worker, _) = env.createPythonWorker( - pythonExec, workerModule, envVars.asScala.toMap) - pythonWorker = Some(worker) + val prevConf = conf.get(PYTHON_USE_DAEMON) + conf.set(PYTHON_USE_DAEMON, false) + try { + val (worker, _) = env.createPythonWorker( + pythonExec, workerModule, envVars.asScala.toMap) + pythonWorker = Some(worker) + } finally { + conf.set(PYTHON_USE_DAEMON, prevConf) + } - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val stream = new BufferedOutputStream(pythonWorker.get.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // TODO(SPARK-44461): verify python version @@ -87,7 +92,8 @@ private[spark] class StreamingPythonRunner( dataOut.write(command.toArray) dataOut.flush() - val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + val dataIn = new DataInputStream( + new BufferedInputStream(pythonWorker.get.getInputStream, bufferSize)) val resFromPython = dataIn.readInt() logInfo(s"Runner initialization returned $resFromPython") From 
8332f0b6d33ade037e65459aced47e18fb41f76c Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sun, 6 Aug 2023 01:13:43 +0800 Subject: [PATCH 36/68] [SPARK-44687][BUILD] Fix mima check for Scala 2.13 after SPARK-44198 merged MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This pr aims add a new `ProblemFilters` to `MimaExcludes.scala` to fix mima check for Scala 2.13 after SPARK-44198 merged. ### Why are the changes needed? Scala 2.13's daily tests have been failing the mima check for several days: - https://github.com/apache/spark/actions/runs/5765663964 image ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Manual verification: 1. The mima check was passing before SPARK-44198. ``` // [SPARK-44425][CONNECT] Validate that user provided sessionId is an UUID git reset --hard a3bd477a6d8c317ee1e9a6aae6ebd2ef4fc67cce dev/change-scala-version.sh 2.13 dev/mima -Pscala-2.13 ``` ``` [success] Total time: 129 s (02:09), completed 2023-8-5 14:21:06 ``` 2. The mima check failed after SPARK-44198 was merged ``` // [SPARK-44198][CORE] Support propagation of the log level to the executors git reset --hard 5fc90fbd4e3235fbcf038f4725037321b8234d94 dev/change-scala-version.sh 2.13 dev/mima -Pscala-2.13 ``` ``` [error] spark-core: Failed binary compatibility check against org.apache.spark:spark-core_2.13:3.4.0! Found 1 potential problems (filtered 4013) [error] * the type hierarchy of object org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#SparkAppConfig is different in current version. Missing types {scala.runtime.AbstractFunction4} [error] filter with: ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$") [error] java.lang.RuntimeException: Failed binary compatibility check against org.apache.spark:spark-core_2.13:3.4.0! 
Found 1 potential problems (filtered 4013) [error] at scala.sys.package$.error(package.scala:30) [error] at com.typesafe.tools.mima.plugin.SbtMima$.reportModuleErrors(SbtMima.scala:89) [error] at com.typesafe.tools.mima.plugin.MimaPlugin$.$anonfun$projectSettings$2(MimaPlugin.scala:36) [error] at com.typesafe.tools.mima.plugin.MimaPlugin$.$anonfun$projectSettings$2$adapted(MimaPlugin.scala:26) [error] at scala.collection.Iterator.foreach(Iterator.scala:943) [error] at scala.collection.Iterator.foreach$(Iterator.scala:943) [error] at scala.collection.AbstractIterator.foreach(Iterator.scala:1431) [error] at com.typesafe.tools.mima.plugin.MimaPlugin$.$anonfun$projectSettings$1(MimaPlugin.scala:26) [error] at com.typesafe.tools.mima.plugin.MimaPlugin$.$anonfun$projectSettings$1$adapted(MimaPlugin.scala:25) [error] at scala.Function1.$anonfun$compose$1(Function1.scala:49) [error] at sbt.internal.util.$tilde$greater.$anonfun$$u2219$1(TypeFunctions.scala:63) [error] at sbt.std.Transform$$anon$4.work(Transform.scala:69) [error] at sbt.Execute.$anonfun$submit$2(Execute.scala:283) [error] at sbt.internal.util.ErrorHandling$.wideConvert(ErrorHandling.scala:24) [error] at sbt.Execute.work(Execute.scala:292) [error] at sbt.Execute.$anonfun$submit$1(Execute.scala:283) [error] at sbt.ConcurrentRestrictions$$anon$4.$anonfun$submitValid$1(ConcurrentRestrictions.scala:265) [error] at sbt.CompletionService$$anon$2.call(CompletionService.scala:65) [error] at java.util.concurrent.FutureTask.run(FutureTask.java:266) [error] at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511) [error] at java.util.concurrent.FutureTask.run(FutureTask.java:266) [error] at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) [error] at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) [error] at java.lang.Thread.run(Thread.java:750) [error] (core / mimaReportBinaryIssues) Failed binary compatibility check against org.apache.spark:spark-core_2.13:3.4.0! Found 1 potential problems (filtered 4013) [error] Total time: 82 s (01:22), completed 2023-8-5 14:23:49 ``` 3. with this pr, mima check pass ``` gh pr checkout 42358 dev/change-scala-version.sh 2.13 dev/mima -Pscala-2.13 ``` ``` [success] Total time: 157 s (02:37), completed 2023-8-5 14:31:05 ``` Closes #42358 from LuciferYang/SPARK-44687. 
Authored-by: yangjie01 Signed-off-by: yangjie01 --- project/MimaExcludes.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 14fa43b56725e..d0fc8f2b11655 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -73,7 +73,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException$"), // [SPARK-44535][CONNECT][SQL] Move required Streaming API to sql/api ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupStateTimeout"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.OutputMode") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.OutputMode"), + // [SPARK-44198][CORE] Support propagation of the log level to the executors + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$") ) // Default exclude rules From d264ee37f316b32bf37fff770b72b5841b84f7cc Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sun, 6 Aug 2023 17:41:11 +0800 Subject: [PATCH 37/68] [SPARK-44688][INFRA] Add a file existence check before executing `free_disk_space` and `free_disk_space_container` ### What changes were proposed in this pull request? This pr add a file existence check before executing `dev/free_disk_space` and `dev/free_disk_space_container` ### Why are the changes needed? We added `free_disk_space` and `free_disk_space_container` to clean up the disk, but because the daily tests of other branches and the master branch share the yml file, we should check if the file exists before execution, otherwise it will affect the daily tests of other branches. - branch-3.5: https://github.com/apache/spark/actions/runs/5761479443 - branch-3.4: https://github.com/apache/spark/actions/runs/5760423900 - branch-3.3: https://github.com/apache/spark/actions/runs/5759384052 image ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Action Closes #42359 from LuciferYang/test-free_disk_space-exist. 
Authored-by: yangjie01 Signed-off-by: yangjie01 --- .github/workflows/build_and_test.yml | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 04585481a9ce6..cd68c0904d9a4 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -241,7 +241,10 @@ jobs: restore-keys: | ${{ matrix.java }}-${{ matrix.hadoop }}-coursier- - name: Free up disk space - run: ./dev/free_disk_space + run: | + if [ -f ./dev/free_disk_space ]; then + ./dev/free_disk_space + fi - name: Install Java ${{ matrix.java }} uses: actions/setup-java@v3 with: @@ -419,7 +422,9 @@ jobs: # uninstall libraries dedicated for ML testing python3.9 -m pip uninstall -y torch torchvision torcheval torchtnt tensorboard mlflow fi - ./dev/free_disk_space_container + if [ -f ./dev/free_disk_space_container ]; then + ./dev/free_disk_space_container + fi - name: Install Java ${{ matrix.java }} uses: actions/setup-java@v3 with: @@ -519,7 +524,10 @@ jobs: restore-keys: | sparkr-coursier- - name: Free up disk space - run: ./dev/free_disk_space_container + run: | + if [ -f ./dev/free_disk_space_container ]; then + ./dev/free_disk_space_container + fi - name: Install Java ${{ inputs.java }} uses: actions/setup-java@v3 with: @@ -629,7 +637,10 @@ jobs: restore-keys: | docs-maven- - name: Free up disk space - run: ./dev/free_disk_space_container + run: | + if [ -f ./dev/free_disk_space_container ]; then + ./dev/free_disk_space_container + fi - name: Install Java 8 uses: actions/setup-java@v3 with: From 41a2a7daeee0a25d39f30364a694becf54ab37e7 Mon Sep 17 00:00:00 2001 From: sychen Date: Sun, 6 Aug 2023 08:24:40 -0500 Subject: [PATCH 38/68] [SPARK-44650][CORE] `spark.executor.defaultJavaOptions` Check illegal java options ### What changes were proposed in this pull request? ### Why are the changes needed? Command ```bash ./bin/spark-shell --conf spark.executor.extraJavaOptions='-Dspark.foo=bar' ``` Error ``` spark.executor.extraJavaOptions is not allowed to set Spark options (was '-Dspark.foo=bar'). Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit. ``` Command ```bash ./bin/spark-shell --conf spark.executor.defaultJavaOptions='-Dspark.foo=bar' ``` Start up normally. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? local test & add UT ``` ./bin/spark-shell --conf spark.executor.defaultJavaOptions='-Dspark.foo=bar' ``` ``` spark.executor.defaultJavaOptions is not allowed to set Spark options (was '-Dspark.foo=bar'). Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit. ``` Closes #42313 from cxzl25/SPARK-44650. 
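A hedged PySpark-side illustration of the same validation (the commit above demonstrates it via spark-shell; this sketch assumes the check fires when the underlying SparkContext is created):

```python
from pyspark.sql import SparkSession

# After this change, spark.executor.defaultJavaOptions is validated the same way
# as spark.executor.extraJavaOptions: Spark options and -Xmx settings are rejected.
builder = (
    SparkSession.builder
    .config("spark.executor.defaultJavaOptions", "-Dspark.foo=bar")
)
# builder.getOrCreate() is expected to fail with:
#   spark.executor.defaultJavaOptions is not allowed to set Spark options
#   (was '-Dspark.foo=bar'). Set them directly on a SparkConf or in a
#   properties file when using ./bin/spark-submit.
```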
Authored-by: sychen Signed-off-by: Sean Owen --- .../scala/org/apache/spark/SparkConf.scala | 25 ++++++++++--------- .../org/apache/spark/SparkConfSuite.scala | 14 +++++++++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 813a14acd19e4..8c054d24b10d7 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -503,8 +503,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria logWarning(msg) } - val executorOptsKey = EXECUTOR_JAVA_OPTIONS.key - // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => val warning = @@ -518,16 +516,19 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } // Validate spark.executor.extraJavaOptions - getOption(executorOptsKey).foreach { javaOpts => - if (javaOpts.contains("-Dspark")) { - val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + - "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." - throw new Exception(msg) - } - if (javaOpts.contains("-Xmx")) { - val msg = s"$executorOptsKey is not allowed to specify max heap memory settings " + - s"(was '$javaOpts'). Use spark.executor.memory instead." - throw new Exception(msg) + Seq(EXECUTOR_JAVA_OPTIONS.key, "spark.executor.defaultJavaOptions").foreach { executorOptsKey => + getOption(executorOptsKey).foreach { javaOpts => + if (javaOpts.contains("-Dspark")) { + val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + + "Set them directly on a SparkConf or in a properties file " + + "when using ./bin/spark-submit." + throw new Exception(msg) + } + if (javaOpts.contains("-Xmx")) { + val msg = s"$executorOptsKey is not allowed to specify max heap memory settings " + + s"(was '$javaOpts'). Use spark.executor.memory instead." + throw new Exception(msg) + } } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 74fd78162218b..75e22e1418b4a 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -498,6 +498,20 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } } + + test("SPARK-44650: spark.executor.defaultJavaOptions Check illegal java options") { + val conf = new SparkConf() + conf.validateSettings() + conf.set(EXECUTOR_JAVA_OPTIONS.key, "-Dspark.foo=bar") + intercept[Exception] { + conf.validateSettings() + } + conf.remove(EXECUTOR_JAVA_OPTIONS.key) + conf.set("spark.executor.defaultJavaOptions", "-Dspark.foo=bar") + intercept[Exception] { + conf.validateSettings() + } + } } class Class1 {} From 74ae1e3434c345ad036131bd9e67687554515e68 Mon Sep 17 00:00:00 2001 From: TongWei1105 Date: Sun, 6 Aug 2023 23:21:21 +0800 Subject: [PATCH 39/68] [SPARK-42500][SQL] ConstantPropagation support more case ### What changes were proposed in this pull request? This PR enhances ConstantPropagation to support more cases. Propagated through other binary comparisons. Propagated across equality comparisons. This can be further optimized to false. ### Why are the changes needed? Improve query performance. 
[Denodo](https://community.denodo.com/docs/html/browse/latest/en/vdp/administration/optimizing_queries/automatic_simplification_of_queries/removing_redundant_branches_of_queries_partitioned_unions) also has a similar optimization. For example: ``` CREATE TABLE t1(a int, b int) using parquet; CREATE TABLE t2(x int, y int) using parquet; CREATE TEMP VIEW v1 AS SELECT * FROM t1 JOIN t2 WHERE a = x AND a = 0 UNION ALL SELECT * FROM t1 JOIN t2 WHERE a = x AND (a IS NULL OR a <> 0); SELECT * FROM v1 WHERE x > 1; ``` Before this PR: ``` == Optimized Logical Plan == Union false, false :- Project [a#0 AS a#12, b#1 AS b#13, x#2 AS x#14, y#3 AS y#15] : +- Join Inner : :- Filter (isnotnull(a#0) AND (a#0 = 0)) : : +- Relation spark_catalog.default.t1[a#0,b#1] parquet : +- Filter (isnotnull(x#2) AND ((0 = x#2) AND (x#2 > 1))) : +- Relation spark_catalog.default.t2[x#2,y#3] parquet +- Join Inner, (a#16 = x#18) :- Filter ((isnull(a#16) OR NOT (a#16 = 0)) AND ((a#16 > 1) AND isnotnull(a#16))) : +- Relation spark_catalog.default.t1[a#16,b#17] parquet +- Filter ((isnotnull(x#18) AND (x#18 > 1)) AND (isnull(x#18) OR NOT (x#18 = 0))) +- Relation spark_catalog.default.t2[x#18,y#19] parquet ``` After this PR: ``` == Optimized Logical Plan == Join Inner, (a#16 = x#18) :- Filter ((isnull(a#16) OR NOT (a#16 = 0)) AND ((a#16 > 1) AND isnotnull(a#16))) : +- Relation spark_catalog.default.t1[a#16,b#17] parquet +- Filter ((isnotnull(x#18) AND (x#18 > 1)) AND (isnull(x#18) OR NOT (x#18 = 0))) +- Relation spark_catalog.default.t2[x#18,y#19] parquet ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #42038 from TongWei1105/SPARK-42500. Authored-by: TongWei1105 Signed-off-by: Yuming Wang --- .../sql/catalyst/optimizer/expressions.scala | 37 +++++++++---------- .../optimizer/ConstantPropagationSuite.scala | 32 +++++++++++++++- 2 files changed, 47 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8cb560199c069..7b44539929c84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -122,8 +122,6 @@ object ConstantPropagation extends Rule[LogicalPlan] { } } - type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] - /** * Traverse a condition as a tree and replace attributes with constant values. * - On matching [[And]], recursively traverse each children and get propagated mappings. @@ -140,23 +138,23 @@ object ConstantPropagation extends Rule[LogicalPlan] { * resulted false * @return A tuple including: * 1. Option[Expression]: optional changed condition after traversal - * 2. EqualityPredicates: propagated mapping of attribute => constant + * 2. 
AttributeMap: propagated mapping of attribute => constant */ private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean) - : (Option[Expression], EqualityPredicates) = + : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) = condition match { case e @ EqualTo(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) + (None, AttributeMap(Map(left -> (right, e)))) case e @ EqualTo(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) + (None, AttributeMap(Map(right -> (left, e)))) case e @ EqualNullSafe(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) + (None, AttributeMap(Map(left -> (right, e)))) case e @ EqualNullSafe(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) + (None, AttributeMap(Map(right -> (left, e)))) case a: And => val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false, nullIsFalse) @@ -183,12 +181,12 @@ object ConstantPropagation extends Rule[LogicalPlan] { } else { None } - (newSelf, Seq.empty) + (newSelf, AttributeMap.empty) case n: Not => // Ignore the EqualityPredicates from children since they are only propagated through And. val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false) - (newChild.map(Not), Seq.empty) - case _ => (None, Seq.empty) + (newChild.map(Not), AttributeMap.empty) + case _ => (None, AttributeMap.empty) } // We need to take into account if an attribute is nullable and the context of the conjunctive @@ -199,16 +197,15 @@ object ConstantPropagation extends Rule[LogicalPlan] { private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) = !ar.nullable || nullIsFalse - private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) - : Expression = { - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet - def replaceConstants0(expression: Expression) = expression transform { - case a: AttributeReference => constantsMap.getOrElse(a, a) - } + private def replaceConstants( + condition: Expression, + equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = { + val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit }) + val predicates = equalityPredicates.values.map(_._2).toSet condition transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) + case b: BinaryComparison if !predicates.contains(b) => b transform { + case a: AttributeReference => constantsMap.getOrElse(a, a) + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index f5f1455f94611..106af71a9d653 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -159,8 +159,9 @@ class ConstantPropagationSuite extends PlanTest { columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) val correctAnswer = testRelation - .select(columnA) 
- .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze + .select(columnA, columnB) + .where(Literal.FalseLiteral) + .select(columnA).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -186,4 +187,31 @@ class ConstantPropagationSuite extends PlanTest { .analyze comparePlans(Optimize.execute(query2), correctAnswer2) } + + test("SPARK-42500: ConstantPropagation supports more cases") { + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze), + testRelation.where(columnA === 1 && columnB > 3).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze), + testRelation.where(Literal.FalseLiteral).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze), + testRelation.where(Literal.FalseLiteral).analyze) + + comparePlans( + Optimize.execute( + testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze), + testRelation.where(columnA === 1 && columnB === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze), + testRelation.where(columnA === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze), + testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze) + } } From d6998979427b6ad3a0f16d6966b3927d40440a60 Mon Sep 17 00:00:00 2001 From: Giambattista Bloisi Date: Sun, 6 Aug 2023 21:47:57 +0200 Subject: [PATCH 40/68] [SPARK-44634][SQL] Encoders.bean does no longer support nested beans with type arguments ### What changes were proposed in this pull request? This PR fixes a regression introduced in Spark 3.4.x where Encoders.bean is no longer able to process nested beans having type arguments. For example: ``` class A { T value; // value getter and setter } class B { A stringHolder; // stringHolder getter and setter } Encoders.bean(B.class); // throws "SparkUnsupportedOperationException: [ENCODER_NOT_FOUND]..." ``` ### Why are the changes needed? JavaTypeInference.encoderFor main match does not manage ParameterizedType and TypeVariable cases. I think this is a regression introduced after getting rid of usage of guava TypeToken: [SPARK-42093 SQL Move JavaTypeInference to AgnosticEncoders](https://github.com/apache/spark/commit/18672003513d5a4aa610b6b94dbbc15c33185d3#diff-1191737b908340a2f4c22b71b1c40ebaa0da9d8b40c958089c346a3bda26943b) hvanhovell cloud-fan In this PR I'm leveraging commons lang3 TypeUtils functionalities to solve ParameterizedType type arguments for classes ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests have been extended to check correct encoding of a nested bean having type arguments. Closes #42327 from gbloisi-openaire/spark-44634. 
Authored-by: Giambattista Bloisi Signed-off-by: Herman van Hovell --- .../sql/catalyst/JavaTypeInference.scala | 84 +++++-------------- .../sql/catalyst/JavaBeanWithGenerics.java | 41 +++++++++ .../sql/catalyst/JavaTypeInferenceSuite.scala | 4 + 3 files changed, 64 insertions(+), 65 deletions(-) create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index f352d28a7b501..3d536b735db59 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.catalyst import java.beans.{Introspector, PropertyDescriptor} import java.lang.reflect.{ParameterizedType, Type, TypeVariable} -import java.util.{ArrayDeque, List => JList, Map => JMap} +import java.util.{List => JList, Map => JMap} import javax.annotation.Nonnull -import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import org.apache.commons.lang3.reflect.{TypeUtils => JavaTypeUtils} + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.errors.ExecutionErrors @@ -57,7 +59,8 @@ object JavaTypeInference { encoderFor(beanType, Set.empty).asInstanceOf[AgnosticEncoder[T]] } - private def encoderFor(t: Type, seenTypeSet: Set[Class[_]]): AgnosticEncoder[_] = t match { + private def encoderFor(t: Type, seenTypeSet: Set[Class[_]], + typeVariables: Map[TypeVariable[_], Type] = Map.empty): AgnosticEncoder[_] = t match { case c: Class[_] if c == java.lang.Boolean.TYPE => PrimitiveBooleanEncoder case c: Class[_] if c == java.lang.Byte.TYPE => PrimitiveByteEncoder @@ -101,18 +104,24 @@ object JavaTypeInference { UDTEncoder(udt, udt.getClass) case c: Class[_] if c.isArray => - val elementEncoder = encoderFor(c.getComponentType, seenTypeSet) + val elementEncoder = encoderFor(c.getComponentType, seenTypeSet, typeVariables) ArrayEncoder(elementEncoder, elementEncoder.nullable) - case ImplementsList(c, Array(elementCls)) => - val element = encoderFor(elementCls, seenTypeSet) + case c: Class[_] if classOf[JList[_]].isAssignableFrom(c) => + val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables) IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false) - case ImplementsMap(c, Array(keyCls, valueCls)) => - val keyEncoder = encoderFor(keyCls, seenTypeSet) - val valueEncoder = encoderFor(valueCls, seenTypeSet) + case c: Class[_] if classOf[JMap[_, _]].isAssignableFrom(c) => + val keyEncoder = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables) + val 
valueEncoder = encoderFor(c.getTypeParameters.array(1), seenTypeSet, typeVariables) MapEncoder(ClassTag(c), keyEncoder, valueEncoder, valueEncoder.nullable) + case tv: TypeVariable[_] => + encoderFor(typeVariables(tv), seenTypeSet, typeVariables) + + case pt: ParameterizedType => + encoderFor(pt.getRawType, seenTypeSet, JavaTypeUtils.getTypeArguments(pt).asScala.toMap) + case c: Class[_] => if (seenTypeSet.contains(c)) { throw ExecutionErrors.cannotHaveCircularReferencesInBeanClassError(c) @@ -124,7 +133,7 @@ object JavaTypeInference { // Note that the fields are ordered by name. val fields = properties.map { property => val readMethod = property.getReadMethod - val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c) + val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, typeVariables) // The existence of `javax.annotation.Nonnull`, means this field is not nullable. val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) EncoderField( @@ -147,59 +156,4 @@ object JavaTypeInference { .filterNot(_.getName == "declaringClass") .filter(_.getReadMethod != null) } - - private class ImplementsGenericInterface(interface: Class[_]) { - assert(interface.isInterface) - assert(interface.getTypeParameters.nonEmpty) - - def unapply(t: Type): Option[(Class[_], Array[Type])] = implementsInterface(t).map { cls => - cls -> findTypeArgumentsForInterface(t) - } - - @tailrec - private def implementsInterface(t: Type): Option[Class[_]] = t match { - case pt: ParameterizedType => implementsInterface(pt.getRawType) - case c: Class[_] if interface.isAssignableFrom(c) => Option(c) - case _ => None - } - - private def findTypeArgumentsForInterface(t: Type): Array[Type] = { - val queue = new ArrayDeque[(Type, Map[Any, Type])] - queue.add(t -> Map.empty) - while (!queue.isEmpty) { - queue.poll() match { - case (pt: ParameterizedType, bindings) => - // translate mappings... - val mappedTypeArguments = pt.getActualTypeArguments.map { - case v: TypeVariable[_] => bindings(v.getName) - case v => v - } - if (pt.getRawType == interface) { - return mappedTypeArguments - } else { - val mappedTypeArgumentMap = mappedTypeArguments - .zipWithIndex.map(_.swap) - .toMap[Any, Type] - queue.add(pt.getRawType -> mappedTypeArgumentMap) - } - case (c: Class[_], indexedBindings) => - val namedBindings = c.getTypeParameters.zipWithIndex.map { - case (parameter, index) => - parameter.getName -> indexedBindings(index) - }.toMap[Any, Type] - val superClass = c.getGenericSuperclass - if (superClass != null) { - queue.add(superClass -> namedBindings) - } - c.getGenericInterfaces.foreach { iface => - queue.add(iface -> namedBindings) - } - } - } - throw ExecutionErrors.unreachableError() - } - } - - private object ImplementsList extends ImplementsGenericInterface(classOf[JList[_]]) - private object ImplementsMap extends ImplementsGenericInterface(classOf[JMap[_, _]]) } diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java new file mode 100644 index 0000000000000..b84a3122cf84c --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst; + +class JavaBeanWithGenerics { + private A attribute; + + private T value; + + public A getAttribute() { + return attribute; + } + + public void setAttribute(A attribute) { + this.attribute = attribute; + } + + public T getValue() { + return value; + } + + public void setValue(T value) { + this.value = value; + } +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala index 35f5bf739bfce..6439997609766 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala @@ -66,6 +66,7 @@ class LeafBean { @BeanProperty var period: java.time.Period = _ @BeanProperty var enum: java.time.Month = _ @BeanProperty val readOnlyString = "read-only" + @BeanProperty var genericNestedBean: JavaBeanWithGenerics[String, String] = _ var nonNullString: String = "value" @javax.annotation.Nonnull @@ -184,6 +185,9 @@ class JavaTypeInferenceSuite extends SparkFunSuite { encoderField("date", STRICT_DATE_ENCODER), encoderField("duration", DayTimeIntervalEncoder), encoderField("enum", JavaEnumEncoder(classTag[java.time.Month])), + encoderField("genericNestedBean", JavaBeanEncoder( + ClassTag(classOf[JavaBeanWithGenerics[String, String]]), + Seq(encoderField("attribute", StringEncoder), encoderField("value", StringEncoder)))), encoderField("instant", STRICT_INSTANT_ENCODER), encoderField("localDate", STRICT_LOCAL_DATE_ENCODER), encoderField("localDateTime", LocalDateTimeEncoder), From a640373fff38f5c594e4e5c30587bcfe823dee1d Mon Sep 17 00:00:00 2001 From: Amanda Liu Date: Mon, 7 Aug 2023 08:54:52 +0900 Subject: [PATCH 41/68] [SPARK-44629][PYTHON][DOCS] Publish PySpark Test Guidelines webpage ### What changes were proposed in this pull request? This PR adds a webpage to the Spark docs website, https://spark.apache.org/docs, to outline PySpark testing best practices. ### Why are the changes needed? The changes are needed to provide PySpark end users with a guideline for how to use PySpark utils (introduced in SPARK-44629) to test PySpark code. ### Does this PR introduce _any_ user-facing change? Yes, the PR publishes a webpage on the Spark website. ### How was this patch tested? Existing tests Closes #42284 from asl3/testing-guidelines. 
Authored-by: Amanda Liu Signed-off-by: Hyukjin Kwon --- python/docs/source/getting_started/index.rst | 1 + .../getting_started/testing_pyspark.ipynb | 485 ++++++++++++++++++ 2 files changed, 486 insertions(+) create mode 100644 python/docs/source/getting_started/testing_pyspark.ipynb diff --git a/python/docs/source/getting_started/index.rst b/python/docs/source/getting_started/index.rst index 3c1c7d80863ce..5f6d306651b92 100644 --- a/python/docs/source/getting_started/index.rst +++ b/python/docs/source/getting_started/index.rst @@ -40,3 +40,4 @@ The list below is the contents of this quickstart page: quickstart_df quickstart_connect quickstart_ps + testing_pyspark diff --git a/python/docs/source/getting_started/testing_pyspark.ipynb b/python/docs/source/getting_started/testing_pyspark.ipynb new file mode 100644 index 0000000000000..268ace04376ba --- /dev/null +++ b/python/docs/source/getting_started/testing_pyspark.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4ee2125b-f889-47e6-9c3d-8bd63a253683", + "metadata": {}, + "source": [ + "# Testing PySpark\n", + "\n", + "This guide is a reference for writing robust tests for PySpark code.\n", + "\n", + "To view the docs for PySpark test utils, see here. To see the code for PySpark built-in test utils, check out the Spark repository here. To see the JIRA board tickets for the PySpark test framework, see here." + ] + }, + { + "cell_type": "markdown", + "id": "0e8ee4b6-9544-45e1-8a91-e71ed8ef8b9d", + "metadata": {}, + "source": [ + "## Build a PySpark Application\n", + "Here is an example for how to start a PySpark application. Feel free to skip to the next section, “Testing your PySpark Application,” if you already have an application you’re ready to test.\n", + "\n", + "First, start your Spark Session." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9af4a35b-17e8-4e45-816b-34c14c5902f7", + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.sql import SparkSession \n", + "from pyspark.sql.functions import col \n", + "\n", + "# Create a SparkSession \n", + "spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate() " + ] + }, + { + "cell_type": "markdown", + "id": "4a4c6efe-91f5-4e18-b4b2-b0401c2368e4", + "metadata": {}, + "source": [ + "Next, create a DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3b483dd8-3a76-41c6-9206-301d7ef314d6", + "metadata": {}, + "outputs": [], + "source": [ + "sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + "\n", + "df = spark.createDataFrame(sample_data)" + ] + }, + { + "cell_type": "markdown", + "id": "e0f44333-0e08-470b-9fa2-38f59e3dbd63", + "metadata": {}, + "source": [ + "Now, let’s define and apply a transformation function to our DataFrame." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a6c0b766-af5f-4e1d-acf8-887d7cf0b0b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---+--------+\n", + "|age| name|\n", + "+---+--------+\n", + "| 30| John D.|\n", + "| 25|Alice G.|\n", + "| 35| Bob T.|\n", + "| 28| Eve A.|\n", + "+---+--------+\n", + "\n" + ] + } + ], + "source": [ + "from pyspark.sql.functions import col, regexp_replace\n", + "\n", + "# Remove additional spaces in name\n", + "def remove_extra_spaces(df, column_name):\n", + " # Remove extra spaces from the specified column\n", + " df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), \"\\\\s+\", \" \"))\n", + " \n", + " return df_transformed\n", + "\n", + "transformed_df = remove_extra_spaces(df, \"name\")\n", + "\n", + "transformed_df.show()" + ] + }, + { + "cell_type": "markdown", + "id": "530beaa6-aabf-43a1-ad2b-361f267e9608", + "metadata": {}, + "source": [ + "## Testing your PySpark Application\n", + "Now let’s test our PySpark transformation function. \n", + "\n", + "One option is to simply eyeball the resulting DataFrame. However, this can be impractical for large DataFrame or input sizes.\n", + "\n", + "A better way is to write tests. Here are some examples of how we can test our code. The examples below apply for Spark 3.5 and above versions.\n", + "\n", + "Note that these examples are not exhaustive, as there are many other test framework alternatives which you can use instead of `unittest` or `pytest`. The built-in PySpark testing util functions are standalone, meaning they can be compatible with any test framework or CI test pipeline.\n" + ] + }, + { + "cell_type": "markdown", + "id": "d84a9fc1-9768-4af4-bfbf-e832f23334dc", + "metadata": {}, + "source": [ + "### Option 1: Using Only PySpark Built-in Test Utility Functions\n", + "\n", + "For simple ad-hoc validation cases, PySpark testing utils like `assertDataFrameEqual` and `assertSchemaEqual` can be used in a standalone context.\n", + "You could easily test PySpark code in a notebook session. 
For example, say you want to assert equality between two DataFrames:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8e533732-ee40-4cd0-9669-8eb92973908a", + "metadata": {}, + "outputs": [], + "source": [ + "import pyspark.testing\n", + "from pyspark.testing.utils import assertDataFrameEqual\n", + "\n", + "# Example 1\n", + "df1 = spark.createDataFrame(data=[(\"1\", 1000), (\"2\", 3000)], schema=[\"id\", \"amount\"])\n", + "df2 = spark.createDataFrame(data=[(\"1\", 1000), (\"2\", 3000)], schema=[\"id\", \"amount\"])\n", + "assertDataFrameEqual(df1, df2) # pass, DataFrames are identical" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2d77a6be-1e50-4c1a-8a44-85cf7dcec3f3", + "metadata": {}, + "outputs": [], + "source": [ + "# Example 2\n", + "df1 = spark.createDataFrame(data=[(\"1\", 0.1), (\"2\", 3.23)], schema=[\"id\", \"amount\"])\n", + "df2 = spark.createDataFrame(data=[(\"1\", 0.109), (\"2\", 3.23)], schema=[\"id\", \"amount\"])\n", + "assertDataFrameEqual(df1, df2, rtol=1e-1) # pass, DataFrames are approx equal by rtol" + ] + }, + { + "cell_type": "markdown", + "id": "76ade5f2-4a1f-4601-9d2a-80da9da950ff", + "metadata": {}, + "source": [ + "You can also simply compare two DataFrame schemas:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "74393af5-40fb-4d04-87cb-265971ffe6d0", + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.testing.utils import assertSchemaEqual\n", + "from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType\n", + "\n", + "s1 = StructType([StructField(\"names\", ArrayType(DoubleType(), True), True)])\n", + "s2 = StructType([StructField(\"names\", ArrayType(DoubleType(), True), True)])\n", + "\n", + "assertSchemaEqual(s1, s2) # pass, schemas are identical" + ] + }, + { + "cell_type": "markdown", + "id": "c67be105-f6b1-4083-ad11-9e819331eae8", + "metadata": {}, + "source": [ + "### Option 2: Using [Unit Test](https://docs.python.org/3/library/unittest.html)\n", + "For more complex testing scenarios, you may want to use a testing framework.\n", + "\n", + "One of the most popular testing framework options is unit tests. Let’s walk through how you can use the built-in Python `unittest` library to write PySpark tests. For more information about the `unittest` library, see here: https://docs.python.org/3/library/unittest.html. \n", + "\n", + "First, you will need a Spark session. You can use the `@classmethod` decorator from the `unittest` package to take care of setting up and tearing down a Spark session." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "54093761-0b49-4aee-baec-2d29bcf13f9f", + "metadata": {}, + "outputs": [], + "source": [ + "import unittest\n", + "\n", + "class PySparkTestCase(unittest.TestCase):\n", + " @classmethod\n", + " def setUpClass(cls):\n", + " cls.spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate() \n", + "\n", + " \n", + " @classmethod\n", + " def tearDownClass(cls):\n", + " cls.spark.stop()" + ] + }, + { + "cell_type": "markdown", + "id": "3de27500-8526-412e-bf09-6927a760c5d7", + "metadata": {}, + "source": [ + "Now let’s write a `unittest` class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "34feb5e1-944f-4f6b-9c5f-3b0bf68c7d05", + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.testing.utils import assertDataFrameEqual\n", + "\n", + "class TestTranformation(PySparkTestCase):\n", + " def test_single_space(self):\n", + " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + " \n", + " # Create a Spark DataFrame\n", + " original_df = spark.createDataFrame(sample_data)\n", + " \n", + " # Apply the transformation function from before\n", + " transformed_df = remove_extra_spaces(original_df, \"name\")\n", + " \n", + " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}]\n", + " \n", + " expected_df = spark.createDataFrame(expected_data)\n", + " \n", + " assertDataFrameEqual(transformed_df, expected_df)\n" + ] + }, + { + "cell_type": "markdown", + "id": "319a690f-71bd-4886-bd3a-424e866525c2", + "metadata": {}, + "source": [ + "When run, `unittest` will pick up all functions with a name beginning with “test.”" + ] + }, + { + "cell_type": "markdown", + "id": "7d79e53d-cc1e-4fdf-a069-478337bed83d", + "metadata": {}, + "source": [ + "### Option 3: Using [Pytest](https://docs.pytest.org/en/7.1.x/contents.html)\n", + "\n", + "We can also write our tests with `pytest`, which is one of the most popular Python testing frameworks. For more information about `pytest`, see the docs here: https://docs.pytest.org/en/7.1.x/contents.html.\n", + "\n", + "Using a `pytest` fixture allows us to share a spark session across tests, tearing it down when the tests are complete." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "60a4f304-1911-4b4d-8ed9-00ecc8b0890b", + "metadata": {}, + "outputs": [], + "source": [ + "import pytest\n", + "\n", + "@pytest.fixture\n", + "def spark_fixture():\n", + " spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate()\n", + " yield spark" + ] + }, + { + "cell_type": "markdown", + "id": "fcb4e26a-9bfc-48a5-8aca-538697d66642", + "metadata": {}, + "source": [ + "We can then define our tests like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "fa5db3a1-7305-44b7-ab84-f5ed55fd2ba9", + "metadata": {}, + "outputs": [], + "source": [ + "import pytest\n", + "from pyspark.testing.utils import assertDataFrameEqual\n", + "\n", + "def test_single_space(spark_fixture):\n", + " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + " \n", + " # Create a Spark DataFrame\n", + " original_df = spark.createDataFrame(sample_data)\n", + " \n", + " # Apply the transformation function from before\n", + " transformed_df = remove_extra_spaces(original_df, \"name\")\n", + " \n", + " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}]\n", + " \n", + " expected_df = spark.createDataFrame(expected_data)\n", + "\n", + " assertDataFrameEqual(transformed_df, expected_df)" + ] + }, + { + "cell_type": "markdown", + "id": "0fc3f394-3260-4e42-82cf-1a7edc859151", + "metadata": {}, + "source": [ + "When you run your test file with the `pytest` command, it will pick up all functions that have their name beginning with “test.”" + ] + }, + { + "cell_type": "markdown", + "id": "d8f50eee-5d0b-4719-b505-1b3ff05c16e8", + "metadata": {}, + "source": [ + "## Putting It All Together!\n", + "\n", + "Let’s see all the steps together, in a Unit Test example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "a2ea9dec-0ac0-4c23-8770-d6cc226d2e97", + "metadata": {}, + "outputs": [], + "source": [ + "# pkg/etl.py\n", + "import unittest\n", + "\n", + "from pyspark.sql import SparkSession \n", + "from pyspark.sql.functions import col\n", + "from pyspark.sql.functions import regexp_replace\n", + "from pyspark.testing.utils import assertDataFrameEqual\n", + "\n", + "# Create a SparkSession \n", + "spark = SparkSession.builder.appName(\"Sample PySpark ETL\").getOrCreate() \n", + "\n", + "sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + "\n", + "df = spark.createDataFrame(sample_data)\n", + "\n", + "# Define DataFrame transformation function\n", + "def remove_extra_spaces(df, column_name):\n", + " # Remove extra spaces from the specified column using regexp_replace\n", + " df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), \"\\\\s+\", \" \"))\n", + "\n", + " return df_transformed" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "248aede2-feb9-4828-bd9c-8e25e6b194ab", + "metadata": {}, + "outputs": [], + "source": [ + "# pkg/test_etl.py\n", + "import unittest\n", + "\n", + "from pyspark.sql import SparkSession \n", + "\n", + "# Define unit test base class\n", + "class PySparkTestCase(unittest.TestCase):\n", + " @classmethod\n", + " def setUpClass(cls):\n", + " cls.spark = SparkSession.builder.appName(\"Sample PySpark ETL\").getOrCreate() \n", + "\n", + " @classmethod\n", + " def tearDownClass(cls):\n", + " cls.spark.stop()\n", + " \n", + "# Define unit test\n", + "class TestTranformation(PySparkTestCase):\n", + " def test_single_space(self):\n", + " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + " \n", + " # Create a Spark DataFrame\n", + " original_df = spark.createDataFrame(sample_data)\n", + " \n", + " # Apply the transformation function from before\n", + " transformed_df = remove_extra_spaces(original_df, \"name\")\n", + " \n", + " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}]\n", + " \n", + " expected_df = spark.createDataFrame(expected_data)\n", + " \n", + " assertDataFrameEqual(transformed_df, expected_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "a77df5b2-f32e-4d8c-a64b-0078dfa21217", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ran 1 test in 1.734s\n", + "\n", + "OK\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unittest.main(argv=[''], verbosity=0, exit=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jupyter-oss-env", + "language": "python", + "name": "jupyter-oss-env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 
9462dcd0e996dd940d4970dc75482f7d088ac2ae Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Mon, 7 Aug 2023 09:07:48 +0900 Subject: [PATCH 42/68] [SPARK-41636][SQL] Make sure `selectFilters` returns predicates in deterministic order ### What changes were proposed in this pull request? Method `DataSourceStrategy#selectFilters`, which is used to determine "pushdown-able" filters, does not preserve the order of the input Seq[Expression] nor does it return the same order across the same plans. This is resulting in CodeGenerator cache misses even when the exact same LogicalPlan is executed. This PR to make sure `selectFilters` returns predicates in deterministic order. ### Why are the changes needed? Make sure `selectFilters` returns predicates in deterministic order, to reduce the probability of codegen cache misses. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? add new test. Closes #42265 from Hisoka-X/SPARK-41636_selectfilters_order. Authored-by: Jia Fan Signed-off-by: Hyukjin Kwon --- .../execution/datasources/DataSourceStrategy.scala | 6 ++++-- .../datasources/DataSourceStrategySuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 5e6e0ad039258..94c2d2ffaca59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale +import scala.collection.immutable.ListMap import scala.collection.mutable import org.apache.hadoop.fs.Path @@ -670,9 +671,10 @@ object DataSourceStrategy // A map from original Catalyst expressions to corresponding translated data source filters. // If a predicate is not in this map, it means it cannot be pushed down. 
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - val translatedMap: Map[Expression, Filter] = predicates.flatMap { p => + // SPARK-41636: we keep the order of the predicates to avoid CodeGenerator cache misses + val translatedMap: Map[Expression, Filter] = ListMap(predicates.flatMap { p => translateFilter(p, supportNestedPredicatePushdown).map(f => p -> f) - }.toMap + }: _*) val pushedFilters: Seq[Filter] = translatedMap.values.toSeq diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index a35fb5f627145..2b9ec97bace1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -324,4 +324,18 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { DataSourceStrategy.translateFilter(catalystFilter, true) } } + + test("SPARK-41636: selectFilters returns predicates in deterministic order") { + + val predicates = Seq(EqualTo($"id", 1), EqualTo($"id", 2), + EqualTo($"id", 3), EqualTo($"id", 4), EqualTo($"id", 5), EqualTo($"id", 6)) + + val (unhandledPredicates, pushedFilters, handledFilters) = + DataSourceStrategy.selectFilters(FakeRelation(), predicates) + assert(unhandledPredicates.equals(predicates)) + assert(pushedFilters.zipWithIndex.forall { case (f, i) => + f.equals(sources.EqualTo("id", i + 1)) + }) + assert(handledFilters.isEmpty) + } } From 656bf36363c466b60d00452399994ccaaa654ed8 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 7 Aug 2023 10:13:59 +0800 Subject: [PATCH 43/68] [SPARK-44690][SPARK-44376][BUILD] Downgrade Scala to 2.13.8 ### What changes were proposed in this pull request? The aim of this PR is to downgrade the Scala 2.13 dependency to 2.13.8 to ensure that Spark can be build with `-target:jvm-1.8`, and tested with Java 11/17. ### Why are the changes needed? As reported in SPARK-44376, there are issues when maven build and test using Java 11/17 with `-target:jvm-1.8`: - run `build/mvn clean install -Pscala-2.13` with Java 17 ``` [INFO] --- scala-maven-plugin:4.8.0:compile (scala-compile-first) spark-core_2.13 --- [INFO] Compiler bridge file: /Users/yangjie01/.sbt/1.0/zinc/org.scala-sbt/org.scala-sbt-compiler-bridge_2.13-1.8.0-bin_2.13.11__61.0-1.8.0_20221110T195421.jar [INFO] compiling 602 Scala sources and 77 Java sources to /Users/yangjie01/SourceCode/git/spark-mine-13/core/target/scala-2.13/classes ... [WARNING] [Warn] : [deprecation | origin= | version=] -target is deprecated: Use -release instead to compile against the correct platform API. 
[ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala:71: not found: value sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:26: not found: object sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:27: not found: object sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:206: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:210: not found: type Unsafe [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:212: not found: type Unsafe [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:213: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:216: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:236: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:26: Unused import [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:27: Unused import [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala:452: not found: value sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:26: not found: object sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:99: not found: type SignalHandler [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:99: not found: type Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:83: not found: type Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:108: not found: type SignalHandler [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:108: not found: value Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:114: not found: type Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:116: not found: value Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:128: not found: value Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:26: Unused import [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:26: Unused import [WARNING] one warning found [ERROR] 23 errors found ``` - 
run `build/mvn clean install -Pscala-2.13 -Djava.version=17` with Java 17 ``` [INFO] --- scala-maven-plugin:4.8.0:compile (scala-compile-first) spark-tags_2.13 --- [INFO] Compiler bridge file: /Users/yangjie01/.sbt/1.0/zinc/org.scala-sbt/org.scala-sbt-compiler-bridge_2.13-1.8.0-bin_2.13.11__61.0-1.8.0_20221110T195421.jar [INFO] compiling 2 Scala sources and 8 Java sources to /Users/yangjie01/SourceCode/git/spark-mine-13/common/tags/target/scala-2.13/classes ... [WARNING] [Warn] : [deprecation | origin= | version=] -target is deprecated: Use -release instead to compile against the correct platform API. [ERROR] [Error] : target platform version 8 is older than the release version 17 [WARNING] one warning found [ERROR] one error found ``` - run `build/mvn clean package -Pscala-2.13 -DskipTests` or `build/mvn clean install -Pscala-2.13 -DskipTests` with Java 8 first, then run `build/mvn test -Pscala-2.13` with Java 17 ``` [INFO] --- scala-maven-plugin:4.8.0:compile (scala-compile-first) spark-core_2.13 --- [INFO] Compiler bridge file: /Users/yangjie01/.sbt/1.0/zinc/org.scala-sbt/org.scala-sbt-compiler-bridge_2.13-1.8.0-bin_2.13.11__61.0-1.8.0_20221110T195421.jar [INFO] compiling 602 Scala sources and 77 Java sources to /Users/yangjie01/SourceCode/git/spark-mine-13/core/target/scala-2.13/classes ... [WARNING] [Warn] : [deprecation | origin= | version=] -target is deprecated: Use -release instead to compile against the correct platform API. [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala:71: not found: value sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:26: not found: object sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:27: not found: object sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:206: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:210: not found: type Unsafe [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:212: not found: type Unsafe [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:213: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:216: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:236: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:26: Unused import [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:27: Unused import [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala:452: not found: value sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:26: not found: object sun [ERROR] [Error] 
/Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:99: not found: type SignalHandler [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:99: not found: type Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:83: not found: type Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:108: not found: type SignalHandler [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:108: not found: value Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:114: not found: type Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:116: not found: value Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:128: not found: value Signal [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:26: Unused import [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/util/SignalUtils.scala:26: Unused import [WARNING] one warning found [ERROR] 23 errors found ``` This is inconsistent with the behavior of the released `Apache Spark` version, so we need to use the previous Scala2.13 version to support this behavior. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Pass GitHub Actions - Manual checked, the above command can run normally after this pr Closes #42364 from LuciferYang/SPARK-44690. Authored-by: yangjie01 Signed-off-by: yangjie01 --- pom.xml | 6 +----- project/SparkBuild.scala | 4 +--- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/pom.xml b/pom.xml index b0d97c2aa0501..76e3596edd430 100644 --- a/pom.xml +++ b/pom.xml @@ -3600,7 +3600,7 @@ scala-2.13 - 2.13.11 + 2.13.8 2.13 @@ -3659,10 +3659,6 @@ --> -Wconf:cat=unused-imports&src=org\/apache\/spark\/graphx\/impl\/VertexPartitionBase.scala:s -Wconf:cat=unused-imports&src=org\/apache\/spark\/graphx\/impl\/VertexPartitionBaseOps.scala:s - - -Wconf:msg=Implicit definition should have explicit type:s diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e585d5dd2b25c..7900762602689 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -286,9 +286,7 @@ object SparkBuild extends PomBuild { // TODO(SPARK-43850): Remove the following suppression rules and remove `import scala.language.higherKinds` // from the corresponding files when Scala 2.12 is no longer supported. 
"-Wconf:cat=unused-imports&src=org\\/apache\\/spark\\/graphx\\/impl\\/VertexPartitionBase.scala:s", - "-Wconf:cat=unused-imports&src=org\\/apache\\/spark\\/graphx\\/impl\\/VertexPartitionBaseOps.scala:s", - // SPARK-40497 Upgrade Scala to 2.13.11 and suppress `Implicit definition should have explicit type` - "-Wconf:msg=Implicit definition should have explicit type:s" + "-Wconf:cat=unused-imports&src=org\\/apache\\/spark\\/graphx\\/impl\\/VertexPartitionBaseOps.scala:s" ) } } From a98b11274d95f7c9f6e550ef6394e803bc0c17ca Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 7 Aug 2023 10:27:53 +0800 Subject: [PATCH 44/68] [SPARK-44693][BUILD] Rename the `object Catalyst` in SparkBuild to `object SqlApi` ### What changes were proposed in this pull request? This PR renames the Setting object used by the `SqlApi` module in `SparkBuild/scala` from `object Catalyst` to `object SqlApi`. ### Why are the changes needed? The `SqlApi` module should use a more appropriate Setting object name. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions Closes #42361 from LuciferYang/rename-catalyst-2-sqlapi. Authored-by: yangjie01 Signed-off-by: yangjie01 --- project/SparkBuild.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7900762602689..bd65d3c4bd4aa 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -449,7 +449,7 @@ object SparkBuild extends PomBuild { enable(Unidoc.settings)(spark) /* Sql-api ANTLR generation settings */ - enable(Catalyst.settings)(sqlApi) + enable(SqlApi.settings)(sqlApi) /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -1169,7 +1169,7 @@ object OldDeps { ) } -object Catalyst { +object SqlApi { import com.simplytyped.Antlr4Plugin import com.simplytyped.Antlr4Plugin.autoImport._ From 7515061ec237b0393e2fcc064a309fae29502dff Mon Sep 17 00:00:00 2001 From: Mathew Jacob Date: Mon, 7 Aug 2023 10:57:02 +0800 Subject: [PATCH 45/68] [SPARK-44264][PYTHON][ML][TESTS][FOLLOWUP] Adding Deepspeed To The Test Dockerfile ### What changes were proposed in this pull request? Added tests to the Dockerfile for tests in OSS Spark CI. ### Why are the changes needed? They'll skip the deepspeed tests otherwise. ### Does this PR introduce _any_ user-facing change? Nope, testing infra. ### How was this patch tested? Running the tests on machine. Closes #42347 from mathewjacob1002/testing_infra. Authored-by: Mathew Jacob Signed-off-by: Ruifeng Zheng --- dev/infra/Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 9d7b29e25b49b..b69e682f239c8 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -73,3 +73,5 @@ RUN python3.9 -m pip install grpcio protobuf googleapis-common-protos grpcio-sta # Add torch as a testing dependency for TorchDistributor RUN python3.9 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu RUN python3.9 -m pip install torcheval +# Add Deepspeed as a testing dependency for DeepspeedTorchDistributor +RUN python3.9 -m pip install deepspeed From df8e52d84d1eabf48f68d09491f66a0835f41693 Mon Sep 17 00:00:00 2001 From: Amanda Liu Date: Mon, 7 Aug 2023 12:01:45 +0900 Subject: [PATCH 46/68] [SPARK-44682][PS] Make pandas error class message_parameters strings ### What changes were proposed in this pull request? 
This PR converts the types for message_parameters for pandas error classes to string, to ensure ability to compare error class messages in tests. ### Why are the changes needed? The change ensures the ability to compare error class messages in tests. ### Does this PR introduce _any_ user-facing change? No, the PR does not affect the user-facing view of the error messages. ### How was this patch tested? Updated `python/pyspark/pandas/tests/test_utils.py` and existing tests Closes #42348 from asl3/string-pandas-error-types. Authored-by: Amanda Liu Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/tests/test_utils.py | 16 +++---- python/pyspark/testing/pandasutils.py | 56 +++++++++++------------ 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py index 3d658446f2766..0bb03dd8749da 100644 --- a/python/pyspark/pandas/tests/test_utils.py +++ b/python/pyspark/pandas/tests/test_utils.py @@ -208,10 +208,10 @@ def test_series_error_assert_pandas_equal(self): exception=pe.exception, error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": series1, - "left_dtype": series1.dtype, - "right": series2, - "right_dtype": series2.dtype, + "left": series1.to_string(), + "left_dtype": str(series1.dtype), + "right": series2.to_string(), + "right_dtype": str(series2.dtype), }, ) @@ -227,9 +227,9 @@ def test_index_error_assert_pandas_equal(self): error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": index1, - "left_dtype": index1.dtype, + "left_dtype": str(index1.dtype), "right": index2, - "right_dtype": index2.dtype, + "right_dtype": str(index2.dtype), }, ) @@ -247,9 +247,9 @@ def test_multiindex_error_assert_pandas_almost_equal(self): error_class="DIFFERENT_PANDAS_MULTIINDEX", message_parameters={ "left": multiindex1, - "left_dtype": multiindex1.dtype, + "left_dtype": str(multiindex1.dtype), "right": multiindex2, - "right_dtype": multiindex2.dtype, + "right_dtype": str(multiindex1.dtype), }, ) diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index 5899925352144..39196873482b1 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -124,10 +124,10 @@ def _assert_pandas_equal( raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": left, - "left_dtype": left.dtype, - "right": right, - "right_dtype": right.dtype, + "left": left.to_string(), + "left_dtype": str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), }, ) elif isinstance(left, pd.Index) and isinstance(right, pd.Index): @@ -143,9 +143,9 @@ def _assert_pandas_equal( error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) else: @@ -228,10 +228,10 @@ def _assert_pandas_almost_equal( raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": left, - "left_dtype": left.dtype, - "right": right, - "right_dtype": right.dtype, + "left": left.to_string(), + "left_dtype": str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), }, ) for lnull, rnull in zip(left.isnull(), right.isnull()): @@ -239,10 +239,10 @@ def _assert_pandas_almost_equal( raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": left, - "left_dtype": 
left.dtype, - "right": right, - "right_dtype": right.dtype, + "left": left.to_string(), + "left_dtype": str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), }, ) for lval, rval in zip(left.dropna(), right.dropna()): @@ -253,10 +253,10 @@ def _assert_pandas_almost_equal( raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": left, - "left_dtype": left.dtype, - "right": right, - "right_dtype": right.dtype, + "left": left.to_string(), + "left_dtype": str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), }, ) elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex): @@ -265,9 +265,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_MULTIINDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) for lval, rval in zip(left, right): @@ -279,9 +279,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_MULTIINDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) elif isinstance(left, pd.Index) and isinstance(right, pd.Index): @@ -290,9 +290,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) for lnull, rnull in zip(left.isnull(), right.isnull()): @@ -301,9 +301,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) for lval, rval in zip(left.dropna(), right.dropna()): @@ -315,9 +315,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) else: From 1f10cc4a59457ed0de0fd4dc0a1c61514d77261a Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 7 Aug 2023 12:01:47 +0500 Subject: [PATCH 47/68] [SPARK-44628][SQL] Clear some unused codes in "***Errors" and extract some common logic ### What changes were proposed in this pull request? The pr aims to clear some unused codes in "***Errors" and extract some common logic. ### Why are the changes needed? Make code clear. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. Closes #42238 from panbingkun/clear_error. 
Authored-by: panbingkun Signed-off-by: Max Gekk --- .../spark/sql/errors/DataTypeErrors.scala | 18 ++-- .../spark/sql/errors/QueryErrorsBase.scala | 6 +- .../sql/errors/QueryExecutionErrors.scala | 86 ------------------- 3 files changed, 10 insertions(+), 100 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index 7a34a386cd889..5e52e283338d3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -192,15 +192,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { decimalPrecision: Int, decimalScale: Int, context: SQLQueryContext = null): ArithmeticException = { - new SparkArithmeticException( - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", - messageParameters = Map( - "value" -> value.toPlainString, - "precision" -> decimalPrecision.toString, - "scale" -> decimalScale.toString, - "config" -> toSQLConf("spark.sql.ansi.enabled")), - context = getQueryContext(context), - summary = getSummary(context)) + numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } def cannotChangeDecimalPrecisionError( @@ -208,6 +200,14 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { decimalPrecision: Int, decimalScale: Int, context: SQLQueryContext = null): ArithmeticException = { + numericValueOutOfRange(value, decimalPrecision, decimalScale, context) + } + + private def numericValueOutOfRange( + value: Decimal, + decimalPrecision: Int, + decimalScale: Int, + context: SQLQueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala index db256fbee8785..26600117a0c54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.errors import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} -import org.apache.spark.sql.catalyst.util.{toPrettySQL, QuotingUtils} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.{DataType, DoubleType, FloatType} /** @@ -55,10 +55,6 @@ private[sql] trait QueryErrorsBase extends DataTypeErrorsBase { quoteByDefault(toPrettySQL(e)) } - def toSQLSchema(schema: String): String = { - QuotingUtils.toSQLSchema(schema) - } - // Converts an error class parameter to its SQL representation def toSQLValue(v: Any, t: DataType): String = Literal.create(v, t) match { case Literal(null, _) => "NULL" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 45b5d6b6692cf..f960a091ec0f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -32,7 +32,6 @@ import org.apache.spark._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.memory.SparkOutOfMemoryError import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.ScalaReflection.Schema import 
org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} @@ -183,10 +182,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map.empty) } - def dataTypeUnsupportedError(dataType: String, failure: String): Throwable = { - DataTypeErrors.dataTypeUnsupportedError(dataType, failure) - } - def failedExecuteUserDefinedFunctionError(functionName: String, inputTypes: String, outputType: String, e: Throwable): Throwable = { new SparkException( @@ -503,10 +498,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("op" -> op.toString(), "pos" -> pos)) } - def unsupportedRoundingMode(roundMode: BigDecimal.RoundingMode.Value): SparkException = { - DataTypeErrors.unsupportedRoundingMode(roundMode) - } - def resolveCannotHandleNestedSchema(plan: LogicalPlan): SparkRuntimeException = { new SparkRuntimeException( errorClass = "_LEGACY_ERROR_TEMP_2030", @@ -1214,52 +1205,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("o" -> o.toString())) } - def unscaledValueTooLargeForPrecisionError( - value: Decimal, - decimalPrecision: Int, - decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { - DataTypeErrors.unscaledValueTooLargeForPrecisionError( - value, decimalPrecision, decimalScale, context) - } - - def decimalPrecisionExceedsMaxPrecisionError( - precision: Int, maxPrecision: Int): SparkArithmeticException = { - DataTypeErrors.decimalPrecisionExceedsMaxPrecisionError(precision, maxPrecision) - } - - def outOfDecimalTypeRangeError(str: UTF8String): SparkArithmeticException = { - new SparkArithmeticException( - errorClass = "NUMERIC_OUT_OF_SUPPORTED_RANGE", - messageParameters = Map( - "value" -> str.toString), - context = Array.empty, - summary = "") - } - - def unsupportedArrayTypeError(clazz: Class[_]): SparkRuntimeException = { - DataTypeErrors.unsupportedJavaTypeError(clazz) - } - - def unsupportedJavaTypeError(clazz: Class[_]): SparkRuntimeException = { - DataTypeErrors.unsupportedJavaTypeError(clazz) - } - - def failedParsingStructTypeError(raw: String): SparkRuntimeException = { - new SparkRuntimeException( - errorClass = "FAILED_PARSE_STRUCT_TYPE", - messageParameters = Map("raw" -> toSQLValue(raw, StringType))) - } - - def cannotMergeDecimalTypesWithIncompatibleScaleError( - leftScale: Int, rightScale: Int): Throwable = { - DataTypeErrors.cannotMergeDecimalTypesWithIncompatibleScaleError(leftScale, rightScale) - } - - def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { - DataTypeErrors.cannotMergeIncompatibleDataTypesError(left, right) - } - def exceedMapSizeLimitError(size: Int): SparkRuntimeException = { new SparkRuntimeException( errorClass = "_LEGACY_ERROR_TEMP_2126", @@ -1344,13 +1289,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map.empty) } - def attributesForTypeUnsupportedError(schema: Schema): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2142", - messageParameters = Map( - "schema" -> schema.toString())) - } - def paramExceedOneCharError(paramName: String): SparkRuntimeException = { new SparkRuntimeException( errorClass = "_LEGACY_ERROR_TEMP_2145", @@ -2004,22 +1942,6 @@ private[sql] object 
QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def unsupportedOperationExceptionError(): SparkUnsupportedOperationException = { - DataTypeErrors.unsupportedOperationExceptionError() - } - - def nullLiteralsCannotBeCastedError(name: String): SparkUnsupportedOperationException = { - DataTypeErrors.nullLiteralsCannotBeCastedError(name) - } - - def notUserDefinedTypeError(name: String, userClass: String): Throwable = { - DataTypeErrors.notUserDefinedTypeError(name, userClass) - } - - def cannotLoadUserDefinedTypeError(name: String, userClass: String): Throwable = { - DataTypeErrors.cannotLoadUserDefinedTypeError(name, userClass) - } - def notPublicClassError(name: String): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( errorClass = "_LEGACY_ERROR_TEMP_2229", @@ -2033,14 +1955,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map.empty) } - def fieldIndexOnRowWithoutSchemaError(): SparkUnsupportedOperationException = { - DataTypeErrors.fieldIndexOnRowWithoutSchemaError() - } - - def valueIsNullError(index: Int): Throwable = { - DataTypeErrors.valueIsNullError(index) - } - def onlySupportDataSourcesProvidingFileFormatError(providingClass: String): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2233", From f139733b92d421233a3fd35374236ef084dfd10d Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Mon, 7 Aug 2023 12:13:57 +0500 Subject: [PATCH 48/68] [SPARK-42321][SQL] Assign name to _LEGACY_ERROR_TEMP_2133 ### What changes were proposed in this pull request? This PR proposes to assign name to _LEGACY_ERROR_TEMP_2133, "CANNOT_PARSE_STRING_AS_DATATYPE". ### Why are the changes needed? Assign name to _LEGACY_ERROR_TEMP_2133 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? add new test Closes #42018 from Hisoka-X/SPARK-42321_LEGACY_ERROR_TEMP_2133. Authored-by: Jia Fan Signed-off-by: Max Gekk --- .../main/resources/error/error-classes.json | 10 +++++----- ...malformed-record-in-parsing-error-class.md | 4 ++++ .../sql/catalyst/json/JacksonParser.scala | 8 ++++---- .../catalyst/util/BadRecordException.scala | 9 +++++++++ .../sql/catalyst/util/FailureSafeParser.scala | 3 +++ .../sql/errors/QueryExecutionErrors.scala | 19 ++++++++++++------- .../errors/QueryExecutionErrorsSuite.scala | 17 +++++++++++++++++ 7 files changed, 54 insertions(+), 16 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 20f2ab4eb24eb..680f787429c70 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -2043,6 +2043,11 @@ "Parsing JSON arrays as structs is forbidden." ] }, + "CANNOT_PARSE_STRING_AS_DATATYPE" : { + "message" : [ + "Cannot parse the value of the field as target spark data type from the input type ." + ] + }, "WITHOUT_SUGGESTION" : { "message" : [ "" @@ -5318,11 +5323,6 @@ "Exception when registering StreamingQueryListener." ] }, - "_LEGACY_ERROR_TEMP_2133" : { - "message" : [ - "Cannot parse field name , field value , [] as target spark data type []." - ] - }, "_LEGACY_ERROR_TEMP_2134" : { "message" : [ "Cannot parse field value for pattern as target spark data type []." 
diff --git a/docs/sql-error-conditions-malformed-record-in-parsing-error-class.md b/docs/sql-error-conditions-malformed-record-in-parsing-error-class.md index ab9582dffcd31..1cc0327af67ba 100644 --- a/docs/sql-error-conditions-malformed-record-in-parsing-error-class.md +++ b/docs/sql-error-conditions-malformed-record-in-parsing-error-class.md @@ -30,6 +30,10 @@ This error class has the following derived error classes: Parsing JSON arrays as structs is forbidden. +## CANNOT_PARSE_STRING_AS_DATATYPE + +Cannot parse the value `` of the field `` as target spark data type `` from the input type ``. + ## WITHOUT_SUGGESTION diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 388edb9024ca1..91c17a475cd94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -231,8 +231,8 @@ class JacksonParser( Float.PositiveInfinity case "-INF" | "-Infinity" if options.allowNonNumericNumbers => Float.NegativeInfinity - case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError( - parser, VALUE_STRING, FloatType) + case _ => throw StringAsDataTypeException(parser.getCurrentName, parser.getText, + FloatType) } } @@ -250,8 +250,8 @@ class JacksonParser( Double.PositiveInfinity case "-INF" | "-Infinity" if options.allowNonNumericNumbers => Double.NegativeInfinity - case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError( - parser, VALUE_STRING, DoubleType) + case _ => throw StringAsDataTypeException(parser.getCurrentName, parser.getText, + DoubleType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala index e1223a71f746b..7bf01fba8cd9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** @@ -56,3 +57,11 @@ case class BadRecordException( * Exception thrown when the underlying parser parses a JSON array as a struct. */ case class JsonArraysAsStructsException() extends RuntimeException() + +/** + * Exception thrown when the underlying parser can not parses a String as a datatype. + */ +case class StringAsDataTypeException( + fieldName: String, + fieldValue: String, + dataType: DataType) extends RuntimeException() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index 2a9370b8c91ce..0a5764e21e14e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -75,6 +75,9 @@ class FailureSafeParser[IN]( // SPARK-42298 we recreate the exception here to make sure the error message // have the record content. 
throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError(e.record().toString) + case StringAsDataTypeException(fieldName, fieldValue, dataType) => + throw QueryExecutionErrors.cannotParseStringAsDataTypeError(e.record().toString, + fieldName, fieldValue, dataType) case _ => throw QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError( toResultRow(e.partialResults().headOption, e.record).toString, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index f960a091ec0f7..f3c5fb4bef3b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1255,15 +1255,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "failFastMode" -> FailFastMode.name)) } - def cannotParseStringAsDataTypeError(parser: JsonParser, token: JsonToken, dataType: DataType) - : SparkRuntimeException = { + def cannotParseStringAsDataTypeError( + recordStr: String, + fieldName: String, + fieldValue: String, + dataType: DataType): SparkRuntimeException = { new SparkRuntimeException( - errorClass = "_LEGACY_ERROR_TEMP_2133", + errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_STRING_AS_DATATYPE", messageParameters = Map( - "fieldName" -> parser.getCurrentName, - "fieldValue" -> parser.getText, - "token" -> token.toString(), - "dataType" -> dataType.toString())) + "badRecord" -> recordStr, + "failFastMode" -> FailFastMode.name, + "fieldName" -> toSQLId(fieldName), + "fieldValue" -> toSQLValue(fieldValue, StringType), + "inputType" -> StringType.toString, + "targetType" -> dataType.toString)) } def emptyJsonFieldValueError(dataType: DataType): SparkRuntimeException = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index ae1c0a86a14c2..fb10e90b6ccea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -725,6 +725,23 @@ class QueryExecutionErrorsSuite } } + test("CANNOT_PARSE_STRING_AS_DATATYPE: parse string as float use from_json") { + val jsonStr = """{"a": "str"}""" + checkError( + exception = intercept[SparkRuntimeException] { + sql(s"""SELECT from_json('$jsonStr', 'a FLOAT', map('mode','FAILFAST'))""").collect() + }, + errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_STRING_AS_DATATYPE", + parameters = Map( + "badRecord" -> jsonStr, + "failFastMode" -> "FAILFAST", + "fieldName" -> "`a`", + "fieldValue" -> "'str'", + "inputType" -> "StringType", + "targetType" -> "FloatType"), + sqlState = "22023") + } + test("BINARY_ARITHMETIC_OVERFLOW: byte plus byte result overflow") { withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { checkError( From 433fb2af8a3ca239958cb7b006e2924ecfac0d56 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 7 Aug 2023 20:46:38 +0900 Subject: [PATCH 49/68] [SPARK-44701][PYTHON][TESTS] Skip `ClassificationTestsOnConnect` when `torch` is not installed ### What changes were proposed in this pull request? Skip `ClassificationTestsOnConnect` when `torch` is not installed ### Why are the changes needed? 
we moved torch on connect tests to `pyspark_ml_connect`, so module `pyspark_connect` won't have `torch` to fix https://github.com/apache/spark/actions/runs/5776211318/job/15655104006 in 3.5 daily GA: ``` Starting test(python3.9): pyspark.ml.tests.connect.test_connect_classification (temp output: /__w/spark/spark/python/target/fbb6a495-df65-4334-8c04-4befc9ee81df/python3.9__pyspark.ml.tests.connect.test_connect_classification__jp1htw6f.log) Traceback (most recent call last): File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main return _run_code(code, main_globals, None, File "/usr/lib/python3.9/runpy.py", line 87, in _run_code exec(code, run_globals) File "/__w/spark/spark/python/pyspark/ml/tests/connect/test_connect_classification.py", line 21, in from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin File "/__w/spark/spark/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py", line 22, in from pyspark.ml.connect.classification import ( File "/__w/spark/spark/python/pyspark/ml/connect/classification.py", line 46, in import torch ModuleNotFoundError: No module named 'torch' ``` ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? CI Closes #42375 from zhengruifeng/torch_skip. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../ml/tests/connect/test_connect_classification.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py index 6ad47322234c5..f3e621c19f0f0 100644 --- a/python/pyspark/ml/tests/connect/test_connect_classification.py +++ b/python/pyspark/ml/tests/connect/test_connect_classification.py @@ -20,7 +20,14 @@ from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin +have_torch = True +try: + import torch # noqa: F401 +except ImportError: + have_torch = False + +@unittest.skipIf(not have_torch, "torch is required") class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( From bf7654998fbbec9d5bdee6f46462cffef495545f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 7 Aug 2023 15:09:58 +0200 Subject: [PATCH 50/68] [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders.scala ### What changes were proposed in this pull request? ### Why are the changes needed? It is currently not possible to create a `RowEncoder` using public API. The internal APIs for this will change in Spark 3.5, this means that library maintainers have to update their code if they use a RowEncoder. To avoid happening again, we add this method to the public API. ### Does this PR introduce _any_ user-facing change? Yes. It adds the `row` method to `Encoders`. ### How was this patch tested? Added tests to connect and sql. Closes #42366 from hvanhovell/SPARK-44686. 
Lead-authored-by: Herman van Hovell Co-authored-by: Hyukjin Kwon Signed-off-by: Herman van Hovell --- .../scala/org/apache/spark/sql/Encoders.scala | 10 +++++- .../apache/spark/sql/JavaEncoderSuite.java | 31 ++++++++++++++++--- project/MimaExcludes.scala | 2 ++ .../java/org/apache/spark/sql/RowFactory.java | 0 .../scala/org/apache/spark/sql/Encoders.scala | 7 +++++ .../apache/spark/sql/JavaDatasetSuite.java | 19 ++++++++++++ 6 files changed, 64 insertions(+), 5 deletions(-) rename sql/{catalyst => api}/src/main/java/org/apache/spark/sql/RowFactory.java (100%) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala index 3f2f7ec96d4f5..74f0133803137 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder => RowEncoderFactory} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.types.StructType /** * Methods for creating an [[Encoder]]. @@ -168,6 +169,13 @@ object Encoders { */ def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass) + /** + * Creates a [[Row]] encoder for schema `schema`. + * + * @since 3.5.0 + */ + def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema) + private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = { ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]] } diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java index c8210a7a485b1..6e5fb72d4964b 100644 --- a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java +++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java @@ -16,21 +16,26 @@ */ package org.apache.spark.sql; +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; + import org.junit.*; import static org.junit.Assert.*; import static org.apache.spark.sql.Encoders.*; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.RowFactory.create; import org.apache.spark.sql.connect.client.SparkConnectClient; import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils; - -import java.math.BigDecimal; -import java.util.Arrays; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.types.StructType; /** * Tests for the encoders class. 
*/ -public class JavaEncoderSuite { +public class JavaEncoderSuite implements Serializable { private static SparkSession spark; @BeforeClass @@ -91,4 +96,22 @@ public void testSimpleEncoders() { dataset(DECIMAL(), bigDec(1000, 2), bigDec(2, 2)) .select(sum(v)).as(DECIMAL()).head().setScale(2)); } + + @Test + public void testRowEncoder() { + final StructType schema = new StructType() + .add("a", "int") + .add("b", "string"); + final Dataset df = spark.range(3) + .map(new MapFunction() { + @Override + public Row call(Long i) { + return create(i.intValue(), "s" + i); + } + }, + Encoders.row(schema)) + .filter(col("a").geq(1)); + final List expected = Arrays.asList(create(1, "s1"), create(2, "s2")); + Assert.assertEquals(expected, df.collectAsList()); + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d0fc8f2b11655..9e5eb66ce94d0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -71,6 +71,8 @@ object MimaExcludes { // [SPARK-44507][SQL][CONNECT] Move AnalysisException to sql/api. ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException$"), + // [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RowFactory"), // [SPARK-44535][CONNECT][SQL] Move required Streaming API to sql/api ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupStateTimeout"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.OutputMode"), diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/api/src/main/java/org/apache/spark/sql/RowFactory.java similarity index 100% rename from sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java rename to sql/api/src/main/java/org/apache/spark/sql/RowFactory.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index a419804488654..9b95f74db3a49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -178,6 +178,13 @@ object Encoders { */ def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + /** + * Creates a [[Row]] encoder for schema `schema`. + * + * @since 3.5.0 + */ + def row(schema: StructType): Encoder[Row] = ExpressionEncoder(schema) + /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 48fd009d6e70f..4f7cf8da78722 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -42,6 +42,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; +import static org.apache.spark.sql.RowFactory.create; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.test.TestSparkSession; @@ -1956,6 +1957,24 @@ public void testSpecificLists() { Assert.assertEquals(beans, dataset.collectAsList()); } + @Test + public void testRowEncoder() { + final StructType schema = new StructType() + .add("a", "int") + .add("b", "string"); + final Dataset df = spark.range(3) + .map(new MapFunction() { + @Override + public Row call(Long i) { + return create(i.intValue(), "s" + i); + } + }, + Encoders.row(schema)) + .filter(col("a").geq(1)); + final List expected = Arrays.asList(create(1, "s1"), create(2, "s2")); + Assert.assertEquals(expected, df.collectAsList()); + } + public static class SpecificListsBean implements Serializable { private ArrayList arrayList; private LinkedList linkedList; From a3a32912be04d3760cb34eb4b79d6d481bbec502 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 7 Aug 2023 21:42:38 +0800 Subject: [PATCH 51/68] [SPARK-44697][CORE] Clean up the deprecated usage of `o.a.commons.lang3.RandomUtils` ### What changes were proposed in this pull request? In `commons-lang3` 3.13.0, `RandomUtils` has been marked as `Deprecated`, the Java doc of `commons-lang3` suggests to instead use the api of `commons-rng`. https://github.com/apache/commons-lang/blob/bcc10b359318397a4d12dbaef22b101725bc6323/src/main/java/org/apache/commons/lang3/RandomUtils.java#L33 ``` * deprecated Use Apache Commons RNG's optimized UniformRandomProvider ``` However, as Spark only uses `RandomUtils` in test code, so this pr attempts to replace `RandomUtils` with `ThreadLocalRandom` to avoid introducing additional third-party dependencies. ### Why are the changes needed? Clean up the use of Deprecated api. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions Closes #42370 from LuciferYang/RandomUtils-2-ThreadLocalRandom. 
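
For reference, a standalone Scala sketch of the migration pattern applied in the touched tests (illustrative only; the object name is hypothetical, the sizes and ranges mirror the tests being changed):

```scala
// Illustrative, self-contained sketch of the RandomUtils -> ThreadLocalRandom migration.
import java.util.concurrent.ThreadLocalRandom

object ThreadLocalRandomSketch {
  def main(args: Array[String]): Unit = {
    // Before: RandomUtils.nextInt(1024, 65536)
    val candidatePort = ThreadLocalRandom.current().nextInt(1024, 65536)

    // Before: randomBytes = RandomUtils.nextBytes(2 * 1024 * 1024)
    // ThreadLocalRandom fills an existing array rather than returning a new one.
    val randomBytes = new Array[Byte](2 * 1024 * 1024)
    ThreadLocalRandom.current().nextBytes(randomBytes)

    println(s"candidate port: $candidatePort, random payload: ${randomBytes.length} bytes")
  }
}
```

Note that `ThreadLocalRandom` fills a caller-provided array via `nextBytes(byte[])` instead of returning a fresh array like `RandomUtils.nextBytes(int)`, which is why the byte-array test now pre-allocates its buffer.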
Authored-by: yangjie01 Signed-off-by: yangjie01 --- .../org/apache/spark/io/GenericFileInputStreamSuite.java | 8 ++++---- .../spark/deploy/master/PersistenceEngineSuite.scala | 4 ++-- .../apache/spark/metrics/InputOutputMetricsSuite.scala | 4 ++-- .../org/apache/spark/storage/BlockManagerSuite.scala | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index ef7c4cbbb799c..4bfb4a2c68c40 100644 --- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.io; import org.apache.commons.io.FileUtils; -import org.apache.commons.lang3.RandomUtils; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -25,6 +24,7 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.util.concurrent.ThreadLocalRandom; import static org.junit.Assert.assertEquals; @@ -33,7 +33,8 @@ */ public abstract class GenericFileInputStreamSuite { - private byte[] randomBytes; + // Create a byte array of size 2 MB with random bytes + private byte[] randomBytes = new byte[2 * 1024 * 1024]; protected File inputFile; @@ -41,8 +42,7 @@ public abstract class GenericFileInputStreamSuite { @Before public void setUp() throws IOException { - // Create a byte array of size 2 MB with random bytes - randomBytes = RandomUtils.nextBytes(2 * 1024 * 1024); + ThreadLocalRandom.current().nextBytes(randomBytes); inputFile = File.createTempFile("temp-file", ".tmp"); FileUtils.writeByteArrayToFile(inputFile, randomBytes); } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 39607621b4c45..998ad21a50d25 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.master import java.net.ServerSocket +import java.util.concurrent.ThreadLocalRandom -import org.apache.commons.lang3.RandomUtils import org.apache.curator.test.TestingServer import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} @@ -117,7 +117,7 @@ class PersistenceEngineSuite extends SparkFunSuite { } private def findFreePort(conf: SparkConf): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) + val candidatePort = ThreadLocalRandom.current().nextInt(1024, 65536) Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 905bb8110736d..3e69f01c09c46 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.metrics import java.io.{File, PrintWriter} +import java.util.concurrent.ThreadLocalRandom import scala.collection.mutable.ArrayBuffer -import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} @@ -54,7 +54,7 @@ 
class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext Utils.tryWithResource(new PrintWriter(tmpFile)) { pw => for (x <- 1 to numRecords) { // scalastyle:off println - pw.println(RandomUtils.nextInt(0, numBuckets)) + pw.println(ThreadLocalRandom.current().nextInt(0, numBuckets)) // scalastyle:on println } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd66dc2c5fb0..dcb69f812a7db 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} import java.nio.ByteBuffer import java.nio.file.Files +import java.util.concurrent.ThreadLocalRandom import scala.collection.JavaConverters._ import scala.collection.mutable @@ -31,7 +32,6 @@ import scala.reflect.ClassTag import scala.reflect.classTag import com.esotericsoftware.kryo.KryoException -import org.apache.commons.lang3.RandomUtils import org.mockito.{ArgumentCaptor, ArgumentMatchers => mc} import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify, when} import org.scalatest.PrivateMethodTester @@ -1887,7 +1887,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } - val candidatePort = RandomUtils.nextInt(1024, 65536) + val candidatePort = ThreadLocalRandom.current().nextInt(1024, 65536) val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") @@ -2274,7 +2274,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } - val candidatePort = RandomUtils.nextInt(1024, 65536) + val candidatePort = ThreadLocalRandom.current().nextInt(1024, 65536) val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") From 2a23c7a18a0ba75d95ee1d898896a8f0dc2c5531 Mon Sep 17 00:00:00 2001 From: Bo Zhang Date: Mon, 7 Aug 2023 22:10:01 +0500 Subject: [PATCH 52/68] [SPARK-38475][CORE] Use error class in org.apache.spark.serializer ### What changes were proposed in this pull request? This PR aims to change exceptions created in package org.apache.spark.serializer to use error class. ### Why are the changes needed? This is to move exceptions created in package org.apache.spark.serializer to error class. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. Closes #42243 from bozhang2820/spark-38475. Lead-authored-by: Bo Zhang Co-authored-by: Bo Zhang Signed-off-by: Max Gekk --- .../main/resources/error/error-classes.json | 21 +++++++++++++++ .../serializer/GenericAvroSerializer.scala | 6 ++--- .../spark/serializer/KryoSerializer.scala | 27 ++++++++++++++----- docs/sql-error-conditions.md | 24 +++++++++++++++++ 4 files changed, 68 insertions(+), 10 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 680f787429c70..0ea1eed35e463 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -831,6 +831,11 @@ "Not found an encoder of the type to Spark SQL internal representation. 
Consider to change the input type to one of supported at '/sql-ref-datatypes.html'." ] }, + "ERROR_READING_AVRO_UNKNOWN_FINGERPRINT" : { + "message" : [ + "Error reading avro data -- encountered an unknown fingerprint: , not sure what schema to use. This could happen if you registered additional schemas after starting your spark context." + ] + }, "EVENT_TIME_IS_NOT_ON_TIMESTAMP_TYPE" : { "message" : [ "The event time has the invalid type , but expected \"TIMESTAMP\"." @@ -864,6 +869,11 @@ ], "sqlState" : "22018" }, + "FAILED_REGISTER_CLASS_WITH_KRYO" : { + "message" : [ + "Failed to register classes with Kryo." + ] + }, "FAILED_RENAME_PATH" : { "message" : [ "Failed to rename to as destination already exists." @@ -1564,6 +1574,12 @@ ], "sqlState" : "22032" }, + "INVALID_KRYO_SERIALIZER_BUFFER_SIZE" : { + "message" : [ + "The value of the config \"\" must be less than 2048 MiB, but got MiB." + ], + "sqlState" : "F0000" + }, "INVALID_LAMBDA_FUNCTION_CALL" : { "message" : [ "Invalid lambda function call." @@ -2006,6 +2022,11 @@ "The join condition has the invalid type , expected \"BOOLEAN\"." ] }, + "KRYO_BUFFER_OVERFLOW" : { + "message" : [ + "Kryo serialization failed: . To avoid this, increase \"\" value." + ] + }, "LOAD_DATA_PATH_NOT_EXISTS" : { "message" : [ "LOAD DATA input path does not exist: ." diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 7d2923fdf3752..d09abff2773b8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -140,9 +140,9 @@ private[serializer] class GenericAvroSerializer[D <: GenericContainer] case Some(s) => new Schema.Parser().setValidateDefaults(false).parse(s) case None => throw new SparkException( - "Error reading attempting to read avro data -- encountered an unknown " + - s"fingerprint: $fingerprint, not sure what schema to use. 
This could happen " + - "if you registered additional schemas after starting your spark context.") + errorClass = "ERROR_READING_AVRO_UNKNOWN_FINGERPRINT", + messageParameters = Map("fingerprint" -> fingerprint.toString), + cause = null) } }) } else { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 826d6789f88ee..f75942cbb879f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -66,15 +66,21 @@ class KryoSerializer(conf: SparkConf) private val bufferSizeKb = conf.get(KRYO_SERIALIZER_BUFFER_SIZE) if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { - throw new IllegalArgumentException(s"${KRYO_SERIALIZER_BUFFER_SIZE.key} must be less than " + - s"2048 MiB, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} MiB.") + throw new SparkIllegalArgumentException( + errorClass = "INVALID_KRYO_SERIALIZER_BUFFER_SIZE", + messageParameters = Map( + "bufferSizeConfKey" -> KRYO_SERIALIZER_BUFFER_SIZE.key, + "bufferSizeConfValue" -> ByteUnit.KiB.toMiB(bufferSizeKb).toString)) } private val bufferSize = ByteUnit.KiB.toBytes(bufferSizeKb).toInt val maxBufferSizeMb = conf.get(KRYO_SERIALIZER_MAX_BUFFER_SIZE).toInt if (maxBufferSizeMb >= ByteUnit.GiB.toMiB(2)) { - throw new IllegalArgumentException(s"${KRYO_SERIALIZER_MAX_BUFFER_SIZE.key} must be less " + - s"than 2048 MiB, got: $maxBufferSizeMb MiB.") + throw new SparkIllegalArgumentException( + errorClass = "INVALID_KRYO_SERIALIZER_BUFFER_SIZE", + messageParameters = Map( + "bufferSizeConfKey" -> KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, + "bufferSizeConfValue" -> maxBufferSizeMb.toString)) } private val maxBufferSize = ByteUnit.MiB.toBytes(maxBufferSizeMb).toInt @@ -183,7 +189,10 @@ class KryoSerializer(conf: SparkConf) .foreach { reg => reg.registerClasses(kryo) } } catch { case e: Exception => - throw new SparkException(s"Failed to register classes with Kryo", e) + throw new SparkException( + errorClass = "FAILED_REGISTER_CLASS_WITH_KRYO", + messageParameters = Map.empty, + cause = e) } } @@ -442,8 +451,12 @@ private[spark] class KryoSerializerInstance( kryo.writeClassAndObject(output, t) } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => - throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + - s"increase ${KRYO_SERIALIZER_MAX_BUFFER_SIZE.key} value.", e) + throw new SparkException( + errorClass = "KRYO_BUFFER_OVERFLOW", + messageParameters = Map( + "exceptionMsg" -> e.getMessage, + "bufferSizeConfKey" -> KRYO_SERIALIZER_MAX_BUFFER_SIZE.key), + cause = e) } finally { releaseKryo(kryo) } diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 5609d60f97419..b59bb1789488e 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -484,6 +484,12 @@ SQLSTATE: none assigned Not found an encoder of the type `` to Spark SQL internal representation. Consider to change the input type to one of supported at '``/sql-ref-datatypes.html'. +### ERROR_READING_AVRO_UNKNOWN_FINGERPRINT + +SQLSTATE: none assigned + +Error reading avro data -- encountered an unknown fingerprint: ``, not sure what schema to use. This could happen if you registered additional schemas after starting your spark context. + ### EVENT_TIME_IS_NOT_ON_TIMESTAMP_TYPE SQLSTATE: none assigned @@ -520,6 +526,12 @@ Failed preparing of the function `` for call. 
Please, double check fun Failed parsing struct: ``. +### FAILED_REGISTER_CLASS_WITH_KRYO + +SQLSTATE: none assigned + +Failed to register classes with Kryo. + ### FAILED_RENAME_PATH [SQLSTATE: 42K04](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -972,6 +984,12 @@ Cannot convert JSON root field to target Spark type. Input schema `` can only contain STRING as a key type for a MAP. +### INVALID_KRYO_SERIALIZER_BUFFER_SIZE + +SQLSTATE: F0000 + +The value of the config "``" must be less than 2048 MiB, but got `` MiB. + ### [INVALID_LAMBDA_FUNCTION_CALL](sql-error-conditions-invalid-lambda-function-call-error-class.html) SQLSTATE: none assigned @@ -1163,6 +1181,12 @@ SQLSTATE: none assigned The join condition `` has the invalid type ``, expected "BOOLEAN". +### KRYO_BUFFER_OVERFLOW + +SQLSTATE: none assigned + +Kryo serialization failed: ``. To avoid this, increase "``" value. + ### LOAD_DATA_PATH_NOT_EXISTS SQLSTATE: none assigned From f1a161cb39504bd625ea7fa50d2cc72a1a2a59e9 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 7 Aug 2023 11:48:24 -0700 Subject: [PATCH 53/68] [SPARK-44561][PYTHON] Fix AssertionError when converting UDTF output to a complex type ### What changes were proposed in this pull request? Fixes AssertionError when converting UDTF output to a complex type by ignore assertions in `_create_converter_from_pandas` to make Arrow raise an error. ### Why are the changes needed? There is an assertion in `_create_converter_from_pandas`, but it should not be applied for Python UDTF case. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added/modified the related tests. Closes #42310 from ueshin/issues/SPARK-44561/udtf_complex_types. Authored-by: Takuya UESHIN Signed-off-by: Takuya UESHIN --- python/pyspark/sql/pandas/serializers.py | 5 +- python/pyspark/sql/pandas/types.py | 108 ++++++-- .../sql/tests/connect/test_parity_udtf.py | 3 + python/pyspark/sql/tests/test_udtf.py | 247 ++++++++++++++++-- 4 files changed, 314 insertions(+), 49 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index f3037c8b39c86..d1a3babb1fdc0 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -571,7 +571,10 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) # TODO(SPARK-43579): cache the converter for reuse conv = _create_converter_from_pandas( - dt, timezone=self._timezone, error_on_duplicated_field_names=False + dt, + timezone=self._timezone, + error_on_duplicated_field_names=False, + ignore_unexpected_complex_type_values=True, ) series = conv(series) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 533620476041a..b02a003e632cb 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -21,7 +21,7 @@ """ import datetime import itertools -from typing import Any, Callable, List, Optional, Union, TYPE_CHECKING +from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING from pyspark.sql.types import ( cast, @@ -750,6 +750,7 @@ def _create_converter_from_pandas( *, timezone: Optional[str], error_on_duplicated_field_names: bool = True, + ignore_unexpected_complex_type_values: bool = False, ) -> Callable[["pd.Series"], "pd.Series"]: """ Create a converter of pandas Series to create Spark DataFrame with Arrow 
optimization. @@ -763,6 +764,17 @@ def _create_converter_from_pandas( error_on_duplicated_field_names : bool, optional Whether raise an exception when there are duplicated field names. (default ``True``) + ignore_unexpected_complex_type_values : bool, optional + Whether ignore the case where unexpected values are given for complex types. + If ``False``, each complex type expects: + + * array type: :class:`Iterable` + * map type: :class:`dict` + * struct type: :class:`dict` or :class:`tuple` + + and raise an AssertionError when the given value is not the expected type. + If ``True``, just ignore and return the give value. + (default ``False``) Returns ------- @@ -781,15 +793,26 @@ def correct_timestamp(pser: pd.Series) -> pd.Series: def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]: if isinstance(dt, ArrayType): - _element_conv = _converter(dt.elementType) - if _element_conv is None: - return None + _element_conv = _converter(dt.elementType) or (lambda x: x) - def convert_array(value: Any) -> Any: - if value is None: - return None - else: - return [_element_conv(v) for v in value] # type: ignore[misc] + if ignore_unexpected_complex_type_values: + + def convert_array(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, Iterable): + return [_element_conv(v) for v in value] + else: + return value + + else: + + def convert_array(value: Any) -> Any: + if value is None: + return None + else: + assert isinstance(value, Iterable) + return [_element_conv(v) for v in value] return convert_array @@ -797,12 +820,24 @@ def convert_array(value: Any) -> Any: _key_conv = _converter(dt.keyType) or (lambda x: x) _value_conv = _converter(dt.valueType) or (lambda x: x) - def convert_map(value: Any) -> Any: - if value is None: - return None - else: - assert isinstance(value, dict) - return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] + if ignore_unexpected_complex_type_values: + + def convert_map(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] + else: + return value + + else: + + def convert_map(value: Any) -> Any: + if value is None: + return None + else: + assert isinstance(value, dict) + return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] return convert_map @@ -820,17 +855,38 @@ def convert_map(value: Any) -> Any: field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields] - def convert_struct(value: Any) -> Any: - if value is None: - return None - elif isinstance(value, dict): - return { - dedup_field_names[i]: field_convs[i](value.get(key, None)) - for i, key in enumerate(field_names) - } - else: - assert isinstance(value, tuple) - return {dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)} + if ignore_unexpected_complex_type_values: + + def convert_struct(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return { + dedup_field_names[i]: field_convs[i](value.get(key, None)) + for i, key in enumerate(field_names) + } + elif isinstance(value, tuple): + return { + dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value) + } + else: + return value + + else: + + def convert_struct(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return { + dedup_field_names[i]: field_convs[i](value.get(key, None)) + for i, key in enumerate(field_names) + } + else: + assert isinstance(value, tuple) + return { + dedup_field_names[i]: 
field_convs[i](v) for i, v in enumerate(value) + } return convert_struct diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 748b611e66707..e12e697e582da 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -45,6 +45,9 @@ def tearDownClass(cls): # TODO: use PySpark error classes instead of SparkConnectGrpcException + def test_struct_output_type_casting_row(self): + self.check_struct_output_type_casting_row(SparkConnectGrpcException) + def test_udtf_with_invalid_return_type(self): @udtf(returnType="int") class TestUDTF: diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 9caf267e48df3..b2f473996bcb6 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -18,9 +18,10 @@ import shutil import tempfile import unittest - from typing import Iterator +from py4j.protocol import Py4JJavaError + from pyspark.errors import ( PySparkAttributeError, PythonException, @@ -582,12 +583,14 @@ def eval(self): assertDataFrameEqual(TestUDTF(), [Row()]) - def _check_result_or_exception(self, func_handler, ret_type, expected): + def _check_result_or_exception( + self, func_handler, ret_type, expected, *, err_type=PythonException + ): func = udtf(func_handler, returnType=ret_type) if not isinstance(expected, str): assertDataFrameEqual(func(), expected) else: - with self.assertRaisesRegex(PythonException, expected): + with self.assertRaisesRegex(err_type, expected): func().collect() def test_numeric_output_type_casting(self): @@ -679,20 +682,129 @@ def eval(self): def test_array_output_type_casting(self): class TestUDTF: def eval(self): - yield [1, 2], + yield [0, 1.1, 2], for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), ("x: int", [Row(x=None)]), - ("x: array", [Row(x=[1, 2])]), - ("x: array", [Row(x=[None, None])]), - ("x: array", [Row(x=["1", "2"])]), - ("x: array", [Row(x=[None, None])]), - ("x: array>", [Row(x=[None, None])]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="[0, 1.1, 2]")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array", [Row(x=[0, None, 2])]), + ("x: array", [Row(x=[None, 1.1, None])]), + ("x: array", [Row(x=["0", "1.1", "2"])]), + ("x: array", [Row(x=[None, None, None])]), + ("x: array>", [Row(x=[None, None, None])]), ("x: map", [Row(x=None)]), + ("x: struct", [Row(x=Row(a=0, b=None, c=2))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_map_output_type_casting(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), + ("x: int", [Row(x=None)]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="{a=0, b=1.1, c=2}")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array", [Row(x=None)]), + ("x: map", [Row(x={"a": "0", "b": "1.1", "c": "2"})]), + ("x: 
map", [Row(x={"a": None, "b": None, "c": None})]), + ("x: map", [Row(x={"a": 0, "b": None, "c": 2})]), + ("x: map", [Row(x={"a": None, "b": 1.1, "c": None})]), + ("x: map>", [Row(x={"a": None, "b": None, "c": None})]), + ("x: struct", [Row(x=Row(a=0))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_dict(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), + ("x: int", [Row(x=None)]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="{a=0, b=1.1, c=2}")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array", [Row(x=None)]), + ("x: map", [Row(x={"a": "0", "b": "1.1", "c": "2"})]), + ("x: struct", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct", [Row(Row(a=0, b=None, c=2))]), + ("x: struct", [Row(Row(a=None, b=1.1, c=None))]), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) + def test_struct_output_type_casting_row(self): + self.check_struct_output_type_casting_row(Py4JJavaError) + + def check_struct_output_type_casting_row(self, error_type): + class TestUDTF: + def eval(self): + yield Row(a=0, b=1.1, c=2), + + err = ("PickleException", error_type) + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", "ValueError"), + ("x: timestamp", "ValueError"), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array", err), + ("x: map", err), + ("x: struct", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct", [Row(Row(a=0, b=None, c=2))]), + ("x: struct", [Row(Row(a=None, b=1.1, c=None))]), + ]: + with self.subTest(ret_type=ret_type): + if isinstance(expected, tuple): + self._check_result_or_exception( + TestUDTF, ret_type, expected[0], err_type=expected[1] + ) + else: + self._check_result_or_exception(TestUDTF, ret_type, expected) + def test_inconsistent_output_types(self): class TestUDTF: def eval(self): @@ -1777,9 +1889,8 @@ def eval(self): ("x: double", [Row(x=1.0)]), ("x: decimal(10, 0)", err), ("x: array", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map", None), - # ("x: struct", None) + ("x: map", err), + ("x: struct", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1806,10 +1917,9 @@ def eval(self): ("x: double", [Row(x=1.0)]), ("x: decimal(10, 0)", [Row(x=1)]), ("x: array", [Row(x=["1"])]), - ("x: array", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map", None), - # ("x: struct", None) + ("x: array", [Row(x=[1])]), + ("x: map", err), + ("x: struct", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1837,9 +1947,8 @@ def eval(self): ("x: decimal(10, 0)", err), ("x: array", [Row(x=["h", "e", "l", "l", "o"])]), ("x: array", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map", None), - # ("x: struct", None) + ("x: map", err), + ("x: 
struct", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1870,9 +1979,103 @@ def eval(self): ("x: array", [Row(x=[0, 1, 2])]), ("x: array", [Row(x=[0, 1.1, 2])]), ("x: array>", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map", None), - # ("x: struct", None) + ("x: map", err), + ("x: struct", err), + ("x: struct", err), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_map_output_type_casting(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array", [Row(x=["a", "b", "c"])]), + ("x: map", err), + ("x: map", err), + ("x: map", [Row(x={"a": 0, "b": 1, "c": 2})]), + ("x: map", [Row(x={"a": 0, "b": 1.1, "c": 2})]), + ("x: map>", err), + ("x: struct", [Row(x=Row(a=0))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_dict(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array", [Row(x=["a", "b", "c"])]), + ("x: map", err), + ("x: struct", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct", [Row(Row(a=0, b=1, c=2))]), + ("x: struct", [Row(Row(a=0, b=1.1, c=2))]), + ("x: struct,b:struct<>,c:struct<>>", err), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_row(self): + class TestUDTF: + def eval(self): + yield Row(a=0, b=1.1, c=2), + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array", [Row(x=["0", "1.1", "2"])]), + ("x: map", err), + ("x: struct", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct", [Row(Row(a=0, b=1, c=2))]), + ("x: struct", [Row(Row(a=0, b=1.1, c=2))]), + ("x: struct,b:struct<>,c:struct<>>", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) From 726ccb532a1b5dbe9c55e68a71d3125570c6738d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 7 Aug 2023 13:56:24 -0700 Subject: [PATCH 54/68] [SPARK-44707][K8S] Use INFO log in `ExecutorPodsWatcher.onClose` if `SparkContext` is stopped ### What changes were proposed in this pull request? This PR is a minor log change which aims to use `INFO`-level log instead of `WARN`-level in `ExecutorPodsWatcher.onClose` if `SparkContext` is stopped. 
Since Spark can distinguish the expected behavior from the error cases, Spark had better avoid WARNING. ### Why are the changes needed? Previously, we have `WARN ExecutorPodsWatchSnapshotSource: Kubernetes client has been closed` message. ``` 23/08/07 18:10:14 INFO SparkContext: SparkContext is stopping with exitCode 0. 23/08/07 18:10:14 WARN TaskSetManager: Lost task 2594.0 in stage 0.0 (TID 2594) ([2620:149:100d:1813::3f86] executor 1615): TaskKilled (another attempt succeeded) 23/08/07 18:10:14 INFO TaskSetManager: task 2594.0 in stage 0.0 (TID 2594) failed, but the task will not be re-executed (either because the task failed with a shuffle data fetch failure, so the previous stage needs to be re-run, or because a different copy of the task has already succeeded). 23/08/07 18:10:14 INFO SparkUI: Stopped Spark web UI at http://xxx:4040 23/08/07 18:10:14 INFO KubernetesClusterSchedulerBackend: Shutting down all executors 23/08/07 18:10:14 INFO KubernetesClusterSchedulerBackend$KubernetesDriverEndpoint: Asking each executor to shut down 23/08/07 18:10:14 WARN ExecutorPodsWatchSnapshotSource: Kubernetes client has been closed. ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. Closes #42381 from dongjoon-hyun/SPARK-44707. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../k8s/ExecutorPodsWatchSnapshotSource.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala index 4809222650d82..6953ed789f797 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -86,12 +86,20 @@ class ExecutorPodsWatchSnapshotSource( } override def onClose(e: WatcherException): Unit = { - logWarning("Kubernetes client has been closed (this is expected if the application is" + - " shutting down.)", e) + if (SparkContext.getActive.map(_.isStopped).getOrElse(true)) { + logInfo("Kubernetes client has been closed.") + } else { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } } override def onClose(): Unit = { - logWarning("Kubernetes client has been closed.") + if (SparkContext.getActive.map(_.isStopped).getOrElse(true)) { + logInfo("Kubernetes client has been closed.") + } else { + logWarning("Kubernetes client has been closed.") + } } } From d5d3f393f16d1c17f88857b81e9bd7573d594d87 Mon Sep 17 00:00:00 2001 From: itholic Date: Tue, 8 Aug 2023 06:31:52 +0900 Subject: [PATCH 55/68] [SPARK-43606][PS] Remove `Int64Index` & `Float64Index` ### What changes were proposed in this pull request? This PR proposes to remove `Int64Index` & `Float64Index` from pandas API on Spark. ### Why are the changes needed? To match the behavior with pandas 2 and above. ### Does this PR introduce _any_ user-facing change? Yes, the `Int64Index` & `Float64Index` will be removed. ### How was this patch tested? Enabling the existing doctests & UTs. Closes #42267 from itholic/SPARK-43245. 
Authored-by: itholic Signed-off-by: Hyukjin Kwon --- dev/sparktestsupport/modules.py | 1 - .../migration_guide/pyspark_upgrade.rst | 1 + .../reference/pyspark.pandas/indexing.rst | 10 - python/pyspark/pandas/__init__.py | 3 - python/pyspark/pandas/base.py | 28 +-- python/pyspark/pandas/frame.py | 8 +- python/pyspark/pandas/indexes/__init__.py | 1 - python/pyspark/pandas/indexes/base.py | 159 ++++++------- python/pyspark/pandas/indexes/category.py | 4 +- python/pyspark/pandas/indexes/datetimes.py | 12 +- python/pyspark/pandas/indexes/numeric.py | 210 ------------------ python/pyspark/pandas/series.py | 4 +- python/pyspark/pandas/spark/accessors.py | 16 +- .../pyspark/pandas/tests/indexes/test_base.py | 45 +--- .../pandas/tests/series/test_compute.py | 4 +- .../pandas/tests/series/test_series.py | 3 +- .../pyspark/pandas/usage_logging/__init__.py | 3 - 17 files changed, 124 insertions(+), 388 deletions(-) delete mode 100644 python/pyspark/pandas/indexes/numeric.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b2f978c47ea30..c5be1957a7dcb 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -668,7 +668,6 @@ def __hash__(self): "pyspark.pandas.indexes.datetimes", "pyspark.pandas.indexes.timedelta", "pyspark.pandas.indexes.multi", - "pyspark.pandas.indexes.numeric", "pyspark.pandas.spark.accessors", "pyspark.pandas.spark.utils", "pyspark.pandas.typedef.typehints", diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index 9bd879fb1a1a6..7a691ee264571 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -22,6 +22,7 @@ Upgrading PySpark Upgrading from PySpark 3.5 to 4.0 --------------------------------- +* In Spark 4.0, ``Int64Index`` and ``Float64Index`` have been removed from pandas API on Spark, ``Index`` should be used directly. * In Spark 4.0, ``DataFrame.iteritems`` has been removed from pandas API on Spark, use ``DataFrame.items`` instead. * In Spark 4.0, ``Series.iteritems`` has been removed from pandas API on Spark, use ``Series.items`` instead. * In Spark 4.0, ``DataFrame.append`` has been removed from pandas API on Spark, use ``ps.concat`` instead. diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst index 15539fa226633..70d463c052a03 100644 --- a/python/docs/source/reference/pyspark.pandas/indexing.rst +++ b/python/docs/source/reference/pyspark.pandas/indexing.rst @@ -166,16 +166,6 @@ Selecting Index.asof Index.isin -.. _api.numeric: - -Numeric Index -------------- -.. autosummary:: - :toctree: api/ - - Int64Index - Float64Index - .. 
_api.categorical: CategoricalIndex diff --git a/python/pyspark/pandas/__init__.py b/python/pyspark/pandas/__init__.py index 980aeab2bee87..d8ce385639cec 100644 --- a/python/pyspark/pandas/__init__.py +++ b/python/pyspark/pandas/__init__.py @@ -61,7 +61,6 @@ from pyspark.pandas.indexes.category import CategoricalIndex from pyspark.pandas.indexes.datetimes import DatetimeIndex from pyspark.pandas.indexes.multi import MultiIndex -from pyspark.pandas.indexes.numeric import Float64Index, Int64Index from pyspark.pandas.indexes.timedelta import TimedeltaIndex from pyspark.pandas.series import Series from pyspark.pandas.groupby import NamedAgg @@ -77,8 +76,6 @@ "Series", "Index", "MultiIndex", - "Int64Index", - "Float64Index", "CategoricalIndex", "DatetimeIndex", "TimedeltaIndex", diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index e005fd19b3009..2de260e6e9351 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -904,8 +904,8 @@ def astype(self: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: 1 2 dtype: int64 - >>> ser.rename("a").to_frame().set_index("a").index.astype('int64') # doctest: +SKIP - Int64Index([1, 2], dtype='int64', name='a') + >>> ser.rename("a").to_frame().set_index("a").index.astype('int64') + Index([1, 2], dtype='int64', name='a') """ return self._dtype_op.astype(self, dtype) @@ -1247,8 +1247,8 @@ def shift( 4 23 Name: Col2, dtype: int64 - >>> df.index.shift(periods=3, fill_value=0) # doctest: +SKIP - Int64Index([0, 0, 0, 0, 1], dtype='int64') + >>> df.index.shift(periods=3, fill_value=0) + Index([0, 0, 0, 0, 1], dtype='int64') """ return self._shift(periods, fill_value).spark.analyzed @@ -1341,8 +1341,8 @@ def value_counts( For Index >>> idx = ps.Index([3, 1, 2, 3, 4, np.nan]) - >>> idx # doctest: +SKIP - Float64Index([3.0, 1.0, 2.0, 3.0, 4.0, nan], dtype='float64') + >>> idx + Index([3.0, 1.0, 2.0, 3.0, 4.0, nan], dtype='float64') >>> idx.value_counts().sort_index() 1.0 1 @@ -1511,8 +1511,8 @@ def nunique(self, dropna: bool = True, approx: bool = False, rsd: float = 0.05) 3 >>> idx = ps.Index([1, 1, 2, None]) - >>> idx # doctest: +SKIP - Float64Index([1.0, 1.0, 2.0, nan], dtype='float64') + >>> idx + Index([1.0, 1.0, 2.0, nan], dtype='float64') >>> idx.nunique() 2 @@ -1586,11 +1586,11 @@ def take(self: IndexOpsLike, indices: Sequence[int]) -> IndexOpsLike: Index >>> psidx = ps.Index([100, 200, 300, 400, 500]) - >>> psidx # doctest: +SKIP - Int64Index([100, 200, 300, 400, 500], dtype='int64') + >>> psidx + Index([100, 200, 300, 400, 500], dtype='int64') - >>> psidx.take([0, 2, 4]).sort_values() # doctest: +SKIP - Int64Index([100, 300, 500], dtype='int64') + >>> psidx.take([0, 2, 4]).sort_values() + Index([100, 300, 500], dtype='int64') MultiIndex @@ -1684,8 +1684,8 @@ def factorize( >>> psidx = ps.Index(['b', None, 'a', 'c', 'b']) >>> codes, uniques = psidx.factorize() - >>> codes # doctest: +SKIP - Int64Index([1, -1, 0, 2, 1], dtype='int64') + >>> codes + Index([1, -1, 0, 2, 1], dtype='int32') >>> uniques Index(['a', 'b', 'c'], dtype='object') """ diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index b960b3444e319..72d4a88b69203 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -734,8 +734,8 @@ def axes(self) -> List: -------- >>> df = ps.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) - >>> df.axes # doctest: +SKIP - [Int64Index([0, 1], dtype='int64'), Index(['col1', 'col2'], dtype='object')] + >>> df.axes + [Index([0, 1], dtype='int64'), 
Index(['col1', 'col2'], dtype='object')] """ return [self.index, self.columns] @@ -8707,8 +8707,8 @@ def join( the original DataFrame’s index in the result unlike pandas. >>> join_psdf = psdf1.join(psdf2.set_index('key'), on='key') - >>> join_psdf.index # doctest: +SKIP - Int64Index([0, 1, 2, 3], dtype='int64') + >>> join_psdf.index + Index([0, 1, 2, 3], dtype='int64') """ if isinstance(right, ps.Series): common = list(self.columns.intersection([right.name])) diff --git a/python/pyspark/pandas/indexes/__init__.py b/python/pyspark/pandas/indexes/__init__.py index 7fde6ffaf61da..0193d366024cd 100644 --- a/python/pyspark/pandas/indexes/__init__.py +++ b/python/pyspark/pandas/indexes/__init__.py @@ -17,5 +17,4 @@ from pyspark.pandas.indexes.base import Index # noqa: F401 from pyspark.pandas.indexes.datetimes import DatetimeIndex # noqa: F401 from pyspark.pandas.indexes.multi import MultiIndex # noqa: F401 -from pyspark.pandas.indexes.numeric import Float64Index, Int64Index # noqa: F401 from pyspark.pandas.indexes.timedelta import TimedeltaIndex # noqa: F401 diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index a8fd07aa2a73d..4c2ab13743592 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -51,7 +51,6 @@ from pyspark.sql import functions as F from pyspark.sql.types import ( DayTimeIntervalType, - FractionalType, IntegralType, TimestampType, TimestampNTZType, @@ -112,19 +111,17 @@ class Index(IndexOpsMixin): -------- MultiIndex : A multi-level, or hierarchical, Index. DatetimeIndex : Index of datetime64 data. - Int64Index : A special case of :class:`Index` with purely integer labels. - Float64Index : A special case of :class:`Index` with purely float labels. Examples -------- - >>> ps.DataFrame({'a': ['a', 'b', 'c']}, index=[1, 2, 3]).index # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> ps.DataFrame({'a': ['a', 'b', 'c']}, index=[1, 2, 3]).index + Index([1, 2, 3], dtype='int64') - >>> ps.DataFrame({'a': [1, 2, 3]}, index=list('abc')).index # doctest: +SKIP + >>> ps.DataFrame({'a': [1, 2, 3]}, index=list('abc')).index Index(['a', 'b', 'c'], dtype='object') - >>> ps.Index([1, 2, 3]) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> ps.Index([1, 2, 3]) + Index([1, 2, 3], dtype='int64') >>> ps.Index(list('abc')) Index(['a', 'b', 'c'], dtype='object') @@ -132,14 +129,14 @@ class Index(IndexOpsMixin): From a Series: >>> s = ps.Series([1, 2, 3], index=[10, 20, 30]) - >>> ps.Index(s) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> ps.Index(s) + Index([1, 2, 3], dtype='int64') From an Index: >>> idx = ps.Index([1, 2, 3]) - >>> ps.Index(idx) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> ps.Index(idx) + Index([1, 2, 3], dtype='int64') """ def __new__( @@ -198,7 +195,6 @@ def _new_instance(anchor: DataFrame) -> "Index": from pyspark.pandas.indexes.category import CategoricalIndex from pyspark.pandas.indexes.datetimes import DatetimeIndex from pyspark.pandas.indexes.multi import MultiIndex - from pyspark.pandas.indexes.numeric import Float64Index, Int64Index from pyspark.pandas.indexes.timedelta import TimedeltaIndex instance: Index @@ -206,14 +202,6 @@ def _new_instance(anchor: DataFrame) -> "Index": instance = object.__new__(MultiIndex) elif isinstance(anchor._internal.index_fields[0].dtype, CategoricalDtype): instance = object.__new__(CategoricalIndex) - elif isinstance( - anchor._internal.spark_type_for(anchor._internal.index_spark_columns[0]), IntegralType - 
): - instance = object.__new__(Int64Index) - elif isinstance( - anchor._internal.spark_type_for(anchor._internal.index_spark_columns[0]), FractionalType - ): - instance = object.__new__(Float64Index) elif isinstance( anchor._internal.spark_type_for(anchor._internal.index_spark_columns[0]), (TimestampType, TimestampNTZType), @@ -800,8 +788,8 @@ def rename(self, name: Union[Name, List[Name]], inplace: bool = False) -> Option Examples -------- >>> df = ps.DataFrame({'a': ['A', 'C'], 'b': ['A', 'B']}, columns=['a', 'b']) - >>> df.index.rename("c") # doctest: +SKIP - Int64Index([0, 1], dtype='int64', name='c') + >>> df.index.rename("c") + Index([0, 1], dtype='int64', name='c') >>> df.set_index("a", inplace=True) >>> df.index.rename("d") @@ -869,11 +857,11 @@ def fillna(self, value: Scalar) -> "Index": Examples -------- >>> idx = ps.Index([1, 2, None]) - >>> idx # doctest: +SKIP - Float64Index([1.0, 2.0, nan], dtype='float64') + >>> idx + Index([1.0, 2.0, nan], dtype='float64') - >>> idx.fillna(0) # doctest: +SKIP - Float64Index([1.0, 2.0, 0.0], dtype='float64') + >>> idx.fillna(0) + Index([1.0, 2.0, 0.0], dtype='float64') """ if not isinstance(value, (float, int, str, bool)): raise TypeError("Unsupported type %s" % type(value).__name__) @@ -1242,8 +1230,7 @@ def unique(self, level: Optional[Union[int, Name]] = None) -> "Index": Examples -------- >>> ps.DataFrame({'a': ['a', 'b', 'c']}, index=[1, 1, 3]).index.unique().sort_values() - ... # doctest: +SKIP - Int64Index([1, 3], dtype='int64') + Index([1, 3], dtype='int64') >>> ps.DataFrame({'a': ['a', 'b', 'c']}, index=['d', 'e', 'e']).index.unique().sort_values() Index(['d', 'e'], dtype='object') @@ -1287,11 +1274,11 @@ def drop(self, labels: List[Any]) -> "Index": Examples -------- >>> index = ps.Index([1, 2, 3]) - >>> index # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> index + Index([1, 2, 3], dtype='int64') - >>> index.drop([1]) # doctest: +SKIP - Int64Index([2, 3], dtype='int64') + >>> index.drop([1]) + Index([2, 3], dtype='int64') """ internal = self._internal.resolved_copy sdf = internal.spark_frame[~internal.index_spark_columns[0].isin(labels)] @@ -1406,8 +1393,8 @@ def droplevel(self, level: Union[int, Name, List[Union[int, Name]]]) -> "Index": MultiIndex([('a', 'b', 1), ('x', 'y', 2)], ) - >>> midx.droplevel([0, 1]) # doctest: +SKIP - Int64Index([1, 2], dtype='int64') + >>> midx.droplevel([0, 1]) + Index([1, 2], dtype='int64') >>> midx.droplevel(0) # doctest: +SKIP MultiIndex([('b', 1), ('y', 2)], @@ -1510,23 +1497,23 @@ def symmetric_difference( >>> s1 = ps.Series([1, 2, 3, 4], index=[1, 2, 3, 4]) >>> s2 = ps.Series([1, 2, 3, 4], index=[2, 3, 4, 5]) - >>> s1.index.symmetric_difference(s2.index) # doctest: +SKIP - Int64Index([5, 1], dtype='int64') + >>> s1.index.symmetric_difference(s2.index) + Index([1, 5], dtype='int64') You can set name of result Index. - >>> s1.index.symmetric_difference(s2.index, result_name='pandas-on-Spark') # doctest: +SKIP - Int64Index([5, 1], dtype='int64', name='pandas-on-Spark') + >>> s1.index.symmetric_difference(s2.index, result_name='pandas-on-Spark') + Index([1, 5], dtype='int64', name='pandas-on-Spark') You can set sort to `True`, if you want to sort the resulting index. 
- >>> s1.index.symmetric_difference(s2.index, sort=True) # doctest: +SKIP - Int64Index([1, 5], dtype='int64') + >>> s1.index.symmetric_difference(s2.index, sort=True) + Index([1, 5], dtype='int64') You can also use the ``^`` operator: - >>> s1.index ^ s2.index # doctest: +SKIP - Int64Index([5, 1], dtype='int64') + >>> (s1.index ^ s2.index) + Index([1, 5], dtype='int64') """ if type(self) != type(other): raise NotImplementedError( @@ -1592,23 +1579,23 @@ def sort_values( Examples -------- >>> idx = ps.Index([10, 100, 1, 1000]) - >>> idx # doctest: +SKIP - Int64Index([10, 100, 1, 1000], dtype='int64') + >>> idx + Index([10, 100, 1, 1000], dtype='int64') Sort values in ascending order (default behavior). - >>> idx.sort_values() # doctest: +SKIP - Int64Index([1, 10, 100, 1000], dtype='int64') + >>> idx.sort_values() + Index([1, 10, 100, 1000], dtype='int64') Sort values in descending order. - >>> idx.sort_values(ascending=False) # doctest: +SKIP - Int64Index([1000, 100, 10, 1], dtype='int64') + >>> idx.sort_values(ascending=False) + Index([1000, 100, 10, 1], dtype='int64') Sort values in descending order, and also get the indices idx was sorted by. - >>> idx.sort_values(ascending=False, return_indexer=True) # doctest: +SKIP - (Int64Index([1000, 100, 10, 1], dtype='int64'), Int64Index([3, 1, 0, 2], dtype='int64')) + >>> idx.sort_values(ascending=False, return_indexer=True) + (Index([1000, 100, 10, 1], dtype='int64'), Index([3, 1, 0, 2], dtype='int64')) Support for MultiIndex. @@ -1631,11 +1618,11 @@ def sort_values( ('a', 'x', 1)], ) - >>> psidx.sort_values(ascending=False, return_indexer=True) # doctest: +SKIP + >>> psidx.sort_values(ascending=False, return_indexer=True) (MultiIndex([('c', 'y', 2), ('b', 'z', 3), ('a', 'x', 1)], - ), Int64Index([1, 2, 0], dtype='int64')) + ), Index([1, 2, 0], dtype='int64')) """ sdf = self._internal.spark_frame if return_indexer: @@ -1772,14 +1759,14 @@ def delete(self, loc: Union[int, List[int]]) -> "Index": Examples -------- >>> psidx = ps.Index([10, 10, 9, 8, 4, 2, 4, 4, 2, 2, 10, 10]) - >>> psidx # doctest: +SKIP - Int64Index([10, 10, 9, 8, 4, 2, 4, 4, 2, 2, 10, 10], dtype='int64') + >>> psidx + Index([10, 10, 9, 8, 4, 2, 4, 4, 2, 2, 10, 10], dtype='int64') - >>> psidx.delete(0).sort_values() # doctest: +SKIP - Int64Index([2, 2, 2, 4, 4, 4, 8, 9, 10, 10, 10], dtype='int64') + >>> psidx.delete(0).sort_values() + Index([2, 2, 2, 4, 4, 4, 8, 9, 10, 10, 10], dtype='int64') - >>> psidx.delete([0, 1, 2, 3, 10, 11]).sort_values() # doctest: +SKIP - Int64Index([2, 2, 2, 4, 4, 4], dtype='int64') + >>> psidx.delete([0, 1, 2, 3, 10, 11]).sort_values() + Index([2, 2, 2, 4, 4, 4], dtype='int64') MultiIndex @@ -1888,11 +1875,11 @@ def append(self, other: "Index") -> "Index": Examples -------- >>> psidx = ps.Index([10, 5, 0, 5, 10, 5, 0, 10]) - >>> psidx # doctest: +SKIP - Int64Index([10, 5, 0, 5, 10, 5, 0, 10], dtype='int64') + >>> psidx + Index([10, 5, 0, 5, 10, 5, 0, 10], dtype='int64') - >>> psidx.append(psidx) # doctest: +SKIP - Int64Index([10, 5, 0, 5, 10, 5, 0, 10, 10, 5, 0, 5, 10, 5, 0, 10], dtype='int64') + >>> psidx.append(psidx) + Index([10, 5, 0, 5, 10, 5, 0, 10, 10, 5, 0, 5, 10, 5, 0, 10], dtype='int64') Support for MiltiIndex @@ -1962,8 +1949,8 @@ def argmax(self) -> int: Examples -------- >>> psidx = ps.Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3]) - >>> psidx # doctest: +SKIP - Int64Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3], dtype='int64') + >>> psidx + Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3], dtype='int64') >>> psidx.argmax() 4 @@ -2010,8 +1997,8 @@ def 
argmin(self) -> int: Examples -------- >>> psidx = ps.Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3]) - >>> psidx # doctest: +SKIP - Int64Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3], dtype='int64') + >>> psidx + Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3], dtype='int64') >>> psidx.argmin() 7 @@ -2062,11 +2049,11 @@ def set_names( Examples -------- >>> idx = ps.Index([1, 2, 3, 4]) - >>> idx # doctest: +SKIP - Int64Index([1, 2, 3, 4], dtype='int64') + >>> idx + Index([1, 2, 3, 4], dtype='int64') - >>> idx.set_names('quarter') # doctest: +SKIP - Int64Index([1, 2, 3, 4], dtype='int64', name='quarter') + >>> idx.set_names('quarter') + Index([1, 2, 3, 4], dtype='int64', name='quarter') For MultiIndex @@ -2119,8 +2106,8 @@ def difference(self, other: "Index", sort: Optional[bool] = None) -> "Index": >>> idx1 = ps.Index([2, 1, 3, 4]) >>> idx2 = ps.Index([3, 4, 5, 6]) - >>> idx1.difference(idx2, sort=True) # doctest: +SKIP - Int64Index([1, 2], dtype='int64') + >>> idx1.difference(idx2, sort=True) + Index([1, 2], dtype='int64') MultiIndex @@ -2136,7 +2123,7 @@ def difference(self, other: "Index", sort: Optional[bool] = None) -> "Index": # Check if the `self` and `other` have different index types. # 1. `self` is Index, `other` is MultiIndex # 2. `self` is MultiIndex, `other` is Index - is_index_types_different = isinstance(other, Index) and not isinstance(self, type(other)) + is_index_types_different = isinstance(other, Index) and (type(self) != type(other)) if is_index_types_different: if isinstance(self, MultiIndex): # In case `self` is MultiIndex and `other` is Index, @@ -2219,8 +2206,8 @@ def is_all_dates(self) -> bool: True >>> idx = ps.Index([0, 1, 2]) - >>> idx # doctest: +SKIP - Int64Index([0, 1, 2], dtype='int64') + >>> idx + Index([0, 1, 2], dtype='int64') >>> idx.is_all_dates False @@ -2403,8 +2390,8 @@ def union( >>> idx1 = ps.Index([1, 2, 3, 4]) >>> idx2 = ps.Index([3, 4, 5, 6]) - >>> idx1.union(idx2).sort_values() # doctest: +SKIP - Int64Index([1, 2, 3, 4, 5, 6], dtype='int64') + >>> idx1.union(idx2).sort_values() + Index([1, 2, 3, 4, 5, 6], dtype='int64') MultiIndex @@ -2469,8 +2456,8 @@ def holds_integer(self) -> bool: When Index contains null values the result can be different with pandas since pandas-on-Spark cast integer to float when Index contains null values. 
- >>> ps.Index([1, 2, 3, None]) # doctest: +SKIP - Float64Index([1.0, 2.0, 3.0, nan], dtype='float64') + >>> ps.Index([1, 2, 3, None]) + Index([1.0, 2.0, 3.0, nan], dtype='float64') Examples -------- @@ -2510,8 +2497,8 @@ def intersection(self, other: Union[DataFrame, Series, "Index", List]) -> "Index -------- >>> idx1 = ps.Index([1, 2, 3, 4]) >>> idx2 = ps.Index([3, 4, 5, 6]) - >>> idx1.intersection(idx2).sort_values() # doctest: +SKIP - Int64Index([3, 4], dtype='int64') + >>> idx1.intersection(idx2).sort_values() + Index([3, 4], dtype='int64') """ from pyspark.pandas.indexes.multi import MultiIndex @@ -2599,14 +2586,14 @@ def insert(self, loc: int, item: Any) -> "Index": Examples -------- >>> psidx = ps.Index([1, 2, 3, 4, 5]) - >>> psidx.insert(3, 100) # doctest: +SKIP - Int64Index([1, 2, 3, 100, 4, 5], dtype='int64') + >>> psidx.insert(3, 100) + Index([1, 2, 3, 100, 4, 5], dtype='int64') For negative values >>> psidx = ps.Index([1, 2, 3, 4, 5]) - >>> psidx.insert(-3, 100) # doctest: +SKIP - Int64Index([1, 2, 100, 3, 4, 5], dtype='int64') + >>> psidx.insert(-3, 100) + Index([1, 2, 100, 3, 4, 5], dtype='int64') """ validate_index_loc(self, loc) loc = loc + len(self) if loc < 0 else loc diff --git a/python/pyspark/pandas/indexes/category.py b/python/pyspark/pandas/indexes/category.py index 7bc87805e1552..94725f90679a6 100644 --- a/python/pyspark/pandas/indexes/category.py +++ b/python/pyspark/pandas/indexes/category.py @@ -141,8 +141,8 @@ def codes(self) -> Index: CategoricalIndex(['a', 'b', 'b', 'c', 'c', 'c'], categories=['a', 'b', 'c'], ordered=False, dtype='category') - >>> idx.codes # doctest: +SKIP - Int64Index([0, 1, 1, 2, 2, 2], dtype='int64') + >>> idx.codes + Index([0, 1, 1, 2, 2, 2], dtype='int8') """ return self._with_new_scol( self.spark.column, diff --git a/python/pyspark/pandas/indexes/datetimes.py b/python/pyspark/pandas/indexes/datetimes.py index 9adef61087a9e..1971d90a74272 100644 --- a/python/pyspark/pandas/indexes/datetimes.py +++ b/python/pyspark/pandas/indexes/datetimes.py @@ -261,7 +261,7 @@ def dayofweek(self) -> Index: -------- >>> idx = ps.date_range('2016-12-31', '2017-01-08', freq='D') # doctest: +SKIP >>> idx.dayofweek # doctest: +SKIP - Int64Index([5, 6, 0, 1, 2, 3, 4, 5, 6], dtype='int64') + Index([5, 6, 0, 1, 2, 3, 4, 5, 6], dtype='int64') """ warnings.warn( "`dayofweek` will return int32 index instead of int 64 index in 4.0.0.", @@ -737,13 +737,13 @@ def indexer_between_time( dtype='datetime64[ns]', freq=None) >>> psidx.indexer_between_time("00:01", "00:02").sort_values() # doctest: +SKIP - Int64Index([1, 2], dtype='int64') + Index([1, 2], dtype='int64') >>> psidx.indexer_between_time("00:01", "00:02", include_end=False) # doctest: +SKIP - Int64Index([1], dtype='int64') + Index([1], dtype='int64') >>> psidx.indexer_between_time("00:01", "00:02", include_start=False) # doctest: +SKIP - Int64Index([2], dtype='int64') + Index([2], dtype='int64') """ def pandas_between_time(pdf) -> ps.DataFrame[int]: # type: ignore[no-untyped-def] @@ -783,10 +783,10 @@ def indexer_at_time(self, time: Union[datetime.time, str], asof: bool = False) - dtype='datetime64[ns]', freq=None) >>> psidx.indexer_at_time("00:00") # doctest: +SKIP - Int64Index([0], dtype='int64') + Index([0], dtype='int64') >>> psidx.indexer_at_time("00:01") # doctest: +SKIP - Int64Index([1], dtype='int64') + Index([1], dtype='int64') """ if asof: raise NotImplementedError("'asof' argument is not supported") diff --git a/python/pyspark/pandas/indexes/numeric.py b/python/pyspark/pandas/indexes/numeric.py 
deleted file mode 100644 index d0b5bc5d15989..0000000000000 --- a/python/pyspark/pandas/indexes/numeric.py +++ /dev/null @@ -1,210 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import warnings -from typing import Any, Optional, Union, cast - -import pandas as pd -from pandas.api.types import is_hashable # type: ignore[attr-defined] - -from pyspark import pandas as ps -from pyspark.pandas._typing import Dtype, Name -from pyspark.pandas.indexes.base import Index -from pyspark.pandas.series import Series - - -class NumericIndex(Index): - """ - Provide numeric type operations. - This is an abstract class. - """ - - pass - - -class IntegerIndex(NumericIndex): - """ - This is an abstract class for Int64Index. - """ - - pass - - -class Int64Index(IntegerIndex): - """ - Immutable sequence used for indexing and alignment. The basic object - storing axis labels for all pandas objects. Int64Index is a special case - of `Index` with purely integer labels. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - data : array-like (1-dimensional) - dtype : NumPy dtype (default: int64) - copy : bool - Make a copy of input ndarray. - name : object - Name to be stored in the index. - - See Also - -------- - Index : The base pandas-on-Spark Index type. - Float64Index : A special case of :class:`Index` with purely float labels. - - Notes - ----- - An Index instance can **only** contain hashable objects. - - Examples - -------- - >>> ps.Int64Index([1, 2, 3]) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') - - From a Series: - - >>> s = ps.Series([1, 2, 3], index=[10, 20, 30]) - >>> ps.Int64Index(s) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') - - From an Index: - - >>> idx = ps.Index([1, 2, 3]) - >>> ps.Int64Index(idx) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') - """ - - def __new__( - cls, - data: Optional[Any] = None, - dtype: Optional[Union[str, Dtype]] = None, - copy: bool = False, - name: Optional[Name] = None, - ) -> "Int64Index": - warnings.warn( - "Int64Index is deprecated in 3.4.0, and will be removed in 4.0.0. Use Index instead.", - FutureWarning, - ) - if not is_hashable(name): - raise TypeError("Index.name must be a hashable type") - - if isinstance(data, (Series, Index)): - if dtype is None: - dtype = "int64" - return cast(Int64Index, Index(data, dtype=dtype, copy=copy, name=name)) - - return cast( - Int64Index, ps.from_pandas(pd.Int64Index(data=data, dtype=dtype, copy=copy, name=name)) - ) - - -class Float64Index(NumericIndex): - """ - Immutable sequence used for indexing and alignment. The basic object - storing axis labels for all pandas objects. Float64Index is a special case - of `Index` with purely float labels. - - .. 
deprecated:: 3.4.0 - - Parameters - ---------- - data : array-like (1-dimensional) - dtype : NumPy dtype (default: float64) - copy : bool - Make a copy of input ndarray. - name : object - Name to be stored in the index. - - See Also - -------- - Index : The base pandas-on-Spark Index type. - Int64Index : A special case of :class:`Index` with purely integer labels. - - Notes - ----- - An Index instance can **only** contain hashable objects. - - Examples - -------- - >>> ps.Float64Index([1.0, 2.0, 3.0]) # doctest: +SKIP - Float64Index([1.0, 2.0, 3.0], dtype='float64') - - From a Series: - - >>> s = ps.Series([1, 2, 3], index=[10, 20, 30]) - >>> ps.Float64Index(s) # doctest: +SKIP - Float64Index([1.0, 2.0, 3.0], dtype='float64') - - From an Index: - - >>> idx = ps.Index([1, 2, 3]) - >>> ps.Float64Index(idx) # doctest: +SKIP - Float64Index([1.0, 2.0, 3.0], dtype='float64') - """ - - def __new__( - cls, - data: Optional[Any] = None, - dtype: Optional[Union[str, Dtype]] = None, - copy: bool = False, - name: Optional[Name] = None, - ) -> "Float64Index": - warnings.warn( - "Float64Index is deprecated in 3.4.0, and will be removed in 4.0.0. Use Index instead.", - FutureWarning, - ) - if not is_hashable(name): - raise TypeError("Index.name must be a hashable type") - - if isinstance(data, (Series, Index)): - if dtype is None: - dtype = "float64" - return cast(Float64Index, Index(data, dtype=dtype, copy=copy, name=name)) - - return cast( - Float64Index, - ps.from_pandas(pd.Float64Index(data=data, dtype=dtype, copy=copy, name=name)), - ) - - -def _test() -> None: - import os - import doctest - import sys - from pyspark.sql import SparkSession - import pyspark.pandas.indexes.numeric - - os.chdir(os.environ["SPARK_HOME"]) - - globs = pyspark.pandas.indexes.numeric.__dict__.copy() - globs["ps"] = pyspark.pandas - spark = ( - SparkSession.builder.master("local[4]") - .appName("pyspark.pandas.indexes.numeric tests") - .getOrCreate() - ) - (failure_count, test_count) = doctest.testmod( - pyspark.pandas.indexes.numeric, - globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE, - ) - spark.stop() - if failure_count: - sys.exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 9fbbadd5420a8..a74f36986f3b5 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -492,8 +492,8 @@ def axes(self) -> List["Index"]: -------- >>> psser = ps.Series([1, 2, 3]) - >>> psser.axes # doctest: +SKIP - [Int64Index([0, 1, 2], dtype='int64')] + >>> psser.axes + [Index([0, 1, 2], dtype='int64')] """ return [self.index] diff --git a/python/pyspark/pandas/spark/accessors.py b/python/pyspark/pandas/spark/accessors.py index f55f70e00924b..bcbe044185a75 100644 --- a/python/pyspark/pandas/spark/accessors.py +++ b/python/pyspark/pandas/spark/accessors.py @@ -105,8 +105,8 @@ def transform(self, func: Callable[[PySparkColumn], PySparkColumn]) -> IndexOpsL 2 1.098612 Name: a, dtype: float64 - >>> df.index.spark.transform(lambda c: c + 10) # doctest: +SKIP - Int64Index([10, 11, 12], dtype='int64') + >>> df.index.spark.transform(lambda c: c + 10) + Index([10, 11, 12], dtype='int64') >>> df.a.spark.transform(lambda c: c + df.b.spark.column) 0 5 @@ -283,13 +283,13 @@ def analyzed(self) -> "ps.Index": -------- >>> import pyspark.pandas as ps >>> idx = ps.Index([1, 2, 3]) - >>> idx # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> idx + Index([1, 2, 3], dtype='int64') The analyzed one should return the 
same value. - >>> idx.spark.analyzed # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> idx.spark.analyzed + Index([1, 2, 3], dtype='int64') However, it won't work with the same anchor Index. @@ -299,8 +299,8 @@ def analyzed(self) -> "ps.Index": ValueError: ... enable 'compute.ops_on_diff_frames' option. >>> with ps.option_context('compute.ops_on_diff_frames', True): - ... (idx + idx.spark.analyzed).sort_values() # doctest: +SKIP - Int64Index([2, 4, 6], dtype='int64') + ... (idx + idx.spark.analyzed).sort_values() + Index([2, 4, 6], dtype='int64') """ from pyspark.pandas.frame import DataFrame diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 6cb7c58197f3c..736c88db4a8f5 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -42,10 +42,6 @@ def pdf(self): index=[0, 1, 3, 5, 6, 8, 9, 9, 9], ) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43606): Enable IndexesTests.test_index_basic for pandas 2.0.0.", - ) def test_index_basic(self): for pdf in [ pd.DataFrame(np.random.randn(10, 5), index=np.random.randint(100, size=10)), @@ -70,22 +66,12 @@ def test_index_basic(self): self.assert_eq(type(psdf.index).__name__, type(pdf.index).__name__) self.assert_eq(ps.Index([])._summary(), "Index: 0 entries") - if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"): - with self.assertRaisesRegexp(ValueError, "The truth value of a Index is ambiguous."): - bool(ps.Index([1])) - with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): - ps.Index([1, 2, 3], name=[(1, 2, 3)]) - with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): - ps.Index([1.0, 2.0, 3.0], name=[(1, 2, 3)]) - else: - with self.assertRaisesRegexp( - ValueError, "The truth value of a Int64Index is ambiguous." 
- ): - bool(ps.Index([1])) - with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): - ps.Int64Index([1, 2, 3], name=[(1, 2, 3)]) - with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): - ps.Float64Index([1.0, 2.0, 3.0], name=[(1, 2, 3)]) + with self.assertRaisesRegexp(ValueError, "The truth value of a Index is ambiguous."): + bool(ps.Index([1])) + with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): + ps.Index([1, 2, 3], name=[(1, 2, 3)]) + with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): + ps.Index([1.0, 2.0, 3.0], name=[(1, 2, 3)]) def test_index_from_series(self): pser = pd.Series([1, 2, 3], name="a", index=[10, 20, 30]) @@ -95,15 +81,8 @@ def test_index_from_series(self): self.assert_eq(ps.Index(psser, dtype="float"), pd.Index(pser, dtype="float")) self.assert_eq(ps.Index(psser, name="x"), pd.Index(pser, name="x")) - if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"): - self.assert_eq(ps.Index(psser, dtype="int64"), pd.Index(pser, dtype="int64")) - self.assert_eq(ps.Index(psser, dtype="float64"), pd.Index(pser, dtype="float64")) - elif LooseVersion(pd.__version__) >= LooseVersion("1.1"): - self.assert_eq(ps.Int64Index(psser), pd.Int64Index(pser)) - self.assert_eq(ps.Float64Index(psser), pd.Float64Index(pser)) - else: - self.assert_eq(ps.Int64Index(psser), pd.Int64Index(pser).rename("a")) - self.assert_eq(ps.Float64Index(psser), pd.Float64Index(pser).rename("a")) + self.assert_eq(ps.Index(psser, dtype="int64"), pd.Index(pser, dtype="int64")) + self.assert_eq(ps.Index(psser, dtype="float64"), pd.Index(pser, dtype="float64")) pser = pd.Series([datetime(2021, 3, 1), datetime(2021, 3, 2)], name="x", index=[10, 20]) psser = ps.from_pandas(pser) @@ -120,12 +99,8 @@ def test_index_from_index(self): self.assert_eq(ps.Index(psidx, name="x"), pd.Index(pidx, name="x")) self.assert_eq(ps.Index(psidx, copy=True), pd.Index(pidx, copy=True)) - if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"): - self.assert_eq(ps.Index(psidx, dtype="int64"), pd.Index(pidx, dtype="int64")) - self.assert_eq(ps.Index(psidx, dtype="float64"), pd.Index(pidx, dtype="float64")) - else: - self.assert_eq(ps.Int64Index(psidx), pd.Int64Index(pidx)) - self.assert_eq(ps.Float64Index(psidx), pd.Float64Index(pidx)) + self.assert_eq(ps.Index(psidx, dtype="int64"), pd.Index(pidx, dtype="int64")) + self.assert_eq(ps.Index(psidx, dtype="float64"), pd.Index(pidx, dtype="float64")) pidx = pd.DatetimeIndex(["2021-03-01", "2021-03-02"]) psidx = ps.from_pandas(pidx) diff --git a/python/pyspark/pandas/tests/series/test_compute.py b/python/pyspark/pandas/tests/series/test_compute.py index 7d39f0523d456..155649179e6ef 100644 --- a/python/pyspark/pandas/tests/series/test_compute.py +++ b/python/pyspark/pandas/tests/series/test_compute.py @@ -471,7 +471,7 @@ def test_factorize(self): pcodes, puniques = pser.factorize() kcodes, kuniques = psser.factorize() self.assert_eq(pcodes, kcodes.to_list()) - # pandas: Float64Index([], dtype='float64') + # pandas: Index([], dtype='float64') self.assert_eq(pd.Index([]), kuniques) pser = pd.Series([np.nan, np.nan]) @@ -479,7 +479,7 @@ def test_factorize(self): pcodes, puniques = pser.factorize() kcodes, kuniques = psser.factorize() self.assert_eq(pcodes, kcodes.to_list()) - # pandas: Float64Index([], dtype='float64') + # pandas: Index([], dtype='float64') self.assert_eq(pd.Index([]), kuniques) # diff --git a/python/pyspark/pandas/tests/series/test_series.py 
b/python/pyspark/pandas/tests/series/test_series.py index f7f186b672452..136d905eb494b 100644 --- a/python/pyspark/pandas/tests/series/test_series.py +++ b/python/pyspark/pandas/tests/series/test_series.py @@ -688,7 +688,8 @@ def test_dot(self): psdf_other = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=["x", "y", "z"]) with self.assertRaisesRegex(ValueError, "matrices are not aligned"): - psdf["b"].dot(psdf_other) + with ps.option_context("compute.ops_on_diff_frames", True): + psdf["b"].dot(psdf_other) def test_tail(self): pser = pd.Series(range(1000), name="Koalas") diff --git a/python/pyspark/pandas/usage_logging/__init__.py b/python/pyspark/pandas/usage_logging/__init__.py index e14a905e78a04..4478b6c85f662 100644 --- a/python/pyspark/pandas/usage_logging/__init__.py +++ b/python/pyspark/pandas/usage_logging/__init__.py @@ -29,7 +29,6 @@ from pyspark.pandas.indexes.category import CategoricalIndex from pyspark.pandas.indexes.datetimes import DatetimeIndex from pyspark.pandas.indexes.multi import MultiIndex -from pyspark.pandas.indexes.numeric import Float64Index, Int64Index from pyspark.pandas.missing.frame import MissingPandasLikeDataFrame from pyspark.pandas.missing.general_functions import MissingPandasLikeGeneralFunctions from pyspark.pandas.missing.groupby import ( @@ -89,8 +88,6 @@ def attach(logger_module: Union[str, ModuleType]) -> None: Series, Index, MultiIndex, - Int64Index, - Float64Index, CategoricalIndex, DatetimeIndex, DataFrameGroupBy, From 42e5daddf3ba16ff7d08e82e51cd8924cc56e180 Mon Sep 17 00:00:00 2001 From: Yihong He Date: Tue, 8 Aug 2023 06:33:48 +0900 Subject: [PATCH 56/68] [SPARK-44575][SQL][CONNECT] Implement basic error translation ### What changes were proposed in this pull request? - Implement basic error translation for spark connect scala client. ### Why are the changes needed? - Better compatibility with the existing control flow ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? `build/sbt "connect-client-jvm/testOnly *Suite"` Closes #42266 from heyihong/SPARK-44575. 
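For illustration, a test-style sketch of the client-side behaviour this enables. This is a hypothetical snippet, not part of the patch: it assumes a ScalaTest suite backed by a Spark Connect `spark` session (as in the updated suites below), and `table_that_does_not_exist` is just a placeholder name.

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.parser.ParseException

// A SQL syntax error now surfaces on the client as ParseException...
intercept[ParseException] {
  spark.sql("selet 1").collect()
}

// ...and analysis failures as AnalysisException instead of a generic SparkException;
// the error class is still carried in the message.
val message = intercept[AnalysisException] {
  spark.sql("select * from table_that_does_not_exist").collect()
}.getMessage
assert(message.contains("TABLE_OR_VIEW_NOT_FOUND"))
```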
Authored-by: Yihong He Signed-off-by: Hyukjin Kwon --- .../client/GrpcExceptionConverter.scala | 54 ++++++++++++++++--- .../org/apache/spark/sql/CatalogSuite.scala | 6 +-- .../apache/spark/sql/ClientE2ETestSuite.scala | 12 +++-- .../spark/sql/DataFrameNaFunctionSuite.scala | 3 +- .../KeyValueGroupedDatasetE2ETestSuite.scala | 3 +- 5 files changed, 60 insertions(+), 18 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 7ff3421a5a045..64d1e5c488ab4 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -16,18 +16,26 @@ */ package org.apache.spark.sql.connect.client +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +import com.google.rpc.ErrorInfo import io.grpc.StatusRuntimeException import io.grpc.protobuf.StatusProto -import org.apache.spark.{SparkException, SparkThrowable} +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.util.JsonUtils -private[client] object GrpcExceptionConverter { +private[client] object GrpcExceptionConverter extends JsonUtils { def convert[T](f: => T): T = { try { f } catch { case e: StatusRuntimeException => - throw toSparkThrowable(e) + throw toThrowable(e) } } @@ -53,11 +61,41 @@ private[client] object GrpcExceptionConverter { } } - private def toSparkThrowable(ex: StatusRuntimeException): SparkThrowable with Throwable = { - val status = StatusProto.fromThrowable(ex) - // TODO: Add finer grained error conversion - new SparkException(status.getMessage, ex.getCause) + private def errorConstructor[T <: Throwable: ClassTag]( + throwableCtr: (String, Throwable) => T): (String, (String, Throwable) => Throwable) = { + val className = implicitly[reflect.ClassTag[T]].runtimeClass.getName + (className, throwableCtr) } -} + private val errorFactory = Map( + errorConstructor((message, _) => new ParseException(None, message, Origin(), Origin())), + errorConstructor((message, cause) => new AnalysisException(message, cause = Option(cause)))) + + private def errorInfoToThrowable(info: ErrorInfo, message: String): Option[Throwable] = { + val classes = + mapper.readValue(info.getMetadataOrDefault("classes", "[]"), classOf[Array[String]]) + classes + .find(errorFactory.contains) + .map { cls => + val constructor = errorFactory.get(cls).get + constructor(message, null) + } + } + + private def toThrowable(ex: StatusRuntimeException): Throwable = { + val status = StatusProto.fromThrowable(ex) + + val fallbackEx = new SparkException(status.getMessage, ex.getCause) + + val errorInfoOpt = status.getDetailsList.asScala + .find(_.is(classOf[ErrorInfo])) + + if (errorInfoOpt.isEmpty) { + return fallbackEx + } + + errorInfoToThrowable(errorInfoOpt.get.unpack(classOf[ErrorInfo]), status.getMessage) + .getOrElse(fallbackEx) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index 00a6bcc9b5c45..fa97498f7e77a 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -46,7 +46,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(databasesWithPattern.length == 0) val database = spark.catalog.getDatabase(db) assert(database.name == db) - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getDatabase("notExists") }.getMessage assert(message.contains("SCHEMA_NOT_FOUND")) @@ -141,7 +141,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(spark.catalog.listTables().collect().map(_.name).toSet == Set(parquetTableName)) } } - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getTable(parquetTableName) }.getMessage assert(message.contains("TABLE_OR_VIEW_NOT_FOUND")) @@ -207,7 +207,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(spark.catalog.getFunction(absFunctionName).name == absFunctionName) val notExistsFunction = "notExists" assert(!spark.catalog.functionExists(notExistsFunction)) - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getFunction(notExistsFunction) }.getMessage assert(message.contains("UNRESOLVED_ROUTINE")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 98fbff84ba674..ebd3d037bba5c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -43,6 +43,12 @@ import org.apache.spark.sql.types._ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { + test("throw ParseException") { + intercept[ParseException] { + spark.sql("selet 1").collect() + } + } + test("spark deep recursion") { var df = spark.range(1) for (a <- 1 to 500) { @@ -88,7 +94,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assume(IntegrationTestUtils.isSparkHiveJarAvailable) withTable("test_martin") { // Fails, because table does not exist. - assertThrows[SparkException] { + assertThrows[AnalysisException] { spark.sql("select * from test_martin").collect() } // Execute eager, DML @@ -177,7 +183,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM StructField("job", StringType) :: Nil)) .csv(testDataPath.toString) // Failed because the path cannot be provided both via option and load method (csv). - assertThrows[SparkException] { + assertThrows[AnalysisException] { df.collect() } } @@ -381,7 +387,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM val df = spark.range(10) val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath // Failed because the path cannot be provided both via option and save method. 
- assertThrows[SparkException] { + assertThrows[AnalysisException] { df.write.option("path", outputFolderPath.toString).save(outputFolderPath.toString) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala index 525a5902525ad..ac64d4411a866 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.SparkException import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{StringType, StructType} @@ -279,7 +278,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { test("drop with col(*)") { val df = createDF() - val ex = intercept[SparkException] { + val ex = intercept[AnalysisException] { df.na.drop("any", Seq("*")).collect() } assert(ex.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index ad75887a7e2db..380ca2fb72b31 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.Timestamp import java.util.Arrays -import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions._ @@ -179,7 +178,7 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { assert(values == Arrays.asList[String]("0", "8,6,4,2,0", "1", "9,7,5,3,1")) // Star is not allowed as group sort column - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { grouped .flatMapSortedGroups(col("*")) { (g, iter) => Iterator(String.valueOf(g), iter.mkString(",")) From 4eea89d339649152a1afcd8b7a32020454e71d42 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 8 Aug 2023 00:42:13 +0200 Subject: [PATCH 57/68] [SPARK-44692][CONNECT][SQL] Move Trigger(s) to sql/api ### What changes were proposed in this pull request? This PR moves `Triggers.scala` and `Trigger.scala` from `sql/core` to `sql/api`, and it removes the duplicates from the connect scala client. ### Why are the changes needed? Not really needed, just some deduplication. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. Closes #42368 from hvanhovell/SPARK-44692. 
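The public API is unchanged by the move: `Trigger` is now defined once in sql/api, so the same streaming code compiles against both sql/core and the Connect Scala client. A minimal usage sketch (assumes a streaming DataFrame `df`; the console sink is only for illustration):

```scala
import org.apache.spark.sql.streaming.Trigger

// The same Trigger class is resolved whether this runs through sql/core or
// the Spark Connect Scala client.
val query = df.writeStream
  .format("console")
  .trigger(Trigger.ProcessingTime("10 seconds")) // or Trigger.AvailableNow(), Trigger.Continuous("1 second")
  .start()

query.awaitTermination()
```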
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../apache/spark/sql/streaming/Trigger.java | 180 ------------------ dev/checkstyle-suppressions.xml | 4 +- project/MimaExcludes.scala | 4 +- .../apache/spark/sql/streaming/Trigger.java | 0 .../sql/execution/streaming/Triggers.scala | 6 +- .../sql/execution/streaming/Triggers.scala | 113 ----------- 6 files changed, 6 insertions(+), 301 deletions(-) delete mode 100644 connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java rename sql/{core => api}/src/main/java/org/apache/spark/sql/streaming/Trigger.java (100%) rename {connector/connect/client/jvm => sql/api}/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala (96%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala diff --git a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java deleted file mode 100644 index 27ffe67d9909c..0000000000000 --- a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming; - -import java.util.concurrent.TimeUnit; - -import scala.concurrent.duration.Duration; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.AvailableNowTrigger$; -import org.apache.spark.sql.execution.streaming.ContinuousTrigger; -import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; -import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger; - -/** - * Policy used to indicate how often results should be produced by a [[StreamingQuery]]. - * - * @since 3.5.0 - */ -@Evolving -public class Trigger { - // This is a copy of the same class in sql/core/.../streaming/Trigger.java - - /** - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is 0, the query will run as fast as possible. - * - * @since 3.5.0 - */ - public static Trigger ProcessingTime(long intervalMs) { - return ProcessingTimeTrigger.create(intervalMs, TimeUnit.MILLISECONDS); - } - - /** - * (Java-friendly) - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is 0, the query will run as fast as possible. 
- * - * {{{ - * import java.util.concurrent.TimeUnit - * df.writeStream().trigger(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { - return ProcessingTimeTrigger.create(interval, timeUnit); - } - - /** - * (Scala-friendly) - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `duration` is 0, the query will run as fast as possible. - * - * {{{ - * import scala.concurrent.duration._ - * df.writeStream.trigger(Trigger.ProcessingTime(10.seconds)) - * }}} - * @since 3.5.0 - */ - public static Trigger ProcessingTime(Duration interval) { - return ProcessingTimeTrigger.apply(interval); - } - - /** - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is effectively 0, the query will run as fast as possible. - * - * {{{ - * df.writeStream.trigger(Trigger.ProcessingTime("10 seconds")) - * }}} - * @since 3.5.0 - */ - public static Trigger ProcessingTime(String interval) { - return ProcessingTimeTrigger.apply(interval); - } - - /** - * A trigger that processes all available data in a single batch then terminates the query. - * - * @since 3.5.0 - * @deprecated This is deprecated as of Spark 3.4.0. Use {@link #AvailableNow()} to leverage - * better guarantee of processing, fine-grained scale of batches, and better gradual - * processing of watermark advancement including no-data batch. - * See the NOTES in {@link #AvailableNow()} for details. - */ - @Deprecated - public static Trigger Once() { - return OneTimeTrigger$.MODULE$; - } - - /** - * A trigger that processes all available data at the start of the query in one or multiple - * batches, then terminates the query. - * - * Users are encouraged to set the source options to control the size of the batch as similar as - * controlling the size of the batch in {@link #ProcessingTime(long)} trigger. - * - * NOTES: - * - This trigger provides a strong guarantee of processing: regardless of how many batches were - * left over in previous run, it ensures all available data at the time of execution gets - * processed before termination. All uncommitted batches will be processed first. - * - Watermark gets advanced per each batch, and no-data batch gets executed before termination - * if the last batch advances the watermark. This helps to maintain smaller and predictable - * state size and smaller latency on the output of stateful operators. - * - * @since 3.5.0 - */ - public static Trigger AvailableNow() { - return AvailableNowTrigger$.MODULE$; - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * @since 3.5.0 - */ - public static Trigger Continuous(long intervalMs) { - return ContinuousTrigger.apply(intervalMs); - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * {{{ - * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(Trigger.Continuous(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - public static Trigger Continuous(long interval, TimeUnit timeUnit) { - return ContinuousTrigger.create(interval, timeUnit); - } - - /** - * (Scala-friendly) - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. 
- * - * {{{ - * import scala.concurrent.duration._ - * df.writeStream.trigger(Trigger.Continuous(10.seconds)) - * }}} - * @since 3.5.0 - */ - public static Trigger Continuous(Duration interval) { - return ContinuousTrigger.apply(interval); - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * {{{ - * df.writeStream.trigger(Trigger.Continuous("10 seconds")) - * }}} - * @since 3.5.0 - */ - public static Trigger Continuous(String interval) { - return ContinuousTrigger.apply(interval); - } -} diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 44876fe69120d..8ba1ff1b3b1eb 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -57,9 +57,7 @@ - + files="sql/api/src/main/java/org/apache/spark/sql/streaming/Trigger.java"/> diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9e5eb66ce94d0..8da132f5de3c5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -77,7 +77,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupStateTimeout"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.OutputMode"), // [SPARK-44198][CORE] Support propagation of the log level to the executors - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$") + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$"), + // [SPARK-44692][CONNECT][SQL] Move Trigger(s) to sql/api + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.Trigger") ) // Default exclude rules diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/Trigger.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java rename to sql/api/src/main/java/org/apache/spark/sql/streaming/Trigger.java diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala similarity index 96% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala rename to sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index ad19ad1780549..37c5b314978bb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.streaming.Trigger import org.apache.spark.unsafe.types.UTF8String private object Triggers { - // This is a copy of the same class in sql/core/...execution/streaming/Triggers.scala - def validate(intervalMs: Long): Unit = { require(intervalMs >= 0, "the interval of trigger should not be negative") } @@ -87,8 +85,8 @@ object ProcessingTimeTrigger { } /** - * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at the - * specified interval. + * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. 
*/ case class ContinuousTrigger(intervalMs: Long) extends Trigger { Triggers.validate(intervalMs) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala deleted file mode 100644 index e6d1381b2b620..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.TimeUnit - -import scala.concurrent.duration.Duration - -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY -import org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToMillis -import org.apache.spark.sql.catalyst.util.IntervalUtils -import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.unsafe.types.UTF8String - -private object Triggers { - def validate(intervalMs: Long): Unit = { - require(intervalMs >= 0, "the interval of trigger should not be negative") - } - - def convert(interval: String): Long = { - val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval)) - if (cal.months != 0) { - throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") - } - val microsInDays = Math.multiplyExact(cal.days, MICROS_PER_DAY) - microsToMillis(Math.addExact(cal.microseconds, microsInDays)) - } - - def convert(interval: Duration): Long = interval.toMillis - - def convert(interval: Long, unit: TimeUnit): Long = unit.toMillis(interval) -} - -/** - * A [[Trigger]] that processes all available data in one batch then terminates the query. - */ -case object OneTimeTrigger extends Trigger - -/** - * A [[Trigger]] that processes all available data in multiple batches then terminates the query. - */ -case object AvailableNowTrigger extends Trigger - -/** - * A [[Trigger]] that runs a query periodically based on the processing time. If `interval` is 0, - * the query will run as fast as possible. - */ -case class ProcessingTimeTrigger(intervalMs: Long) extends Trigger { - Triggers.validate(intervalMs) -} - -object ProcessingTimeTrigger { - import Triggers._ - - def apply(interval: String): ProcessingTimeTrigger = { - ProcessingTimeTrigger(convert(interval)) - } - - def apply(interval: Duration): ProcessingTimeTrigger = { - ProcessingTimeTrigger(convert(interval)) - } - - def create(interval: String): ProcessingTimeTrigger = { - apply(interval) - } - - def create(interval: Long, unit: TimeUnit): ProcessingTimeTrigger = { - ProcessingTimeTrigger(convert(interval, unit)) - } -} - -/** - * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. 
- */ -case class ContinuousTrigger(intervalMs: Long) extends Trigger { - Triggers.validate(intervalMs) -} - -object ContinuousTrigger { - import Triggers._ - - def apply(interval: String): ContinuousTrigger = { - ContinuousTrigger(convert(interval)) - } - - def apply(interval: Duration): ContinuousTrigger = { - ContinuousTrigger(convert(interval)) - } - - def create(interval: String): ContinuousTrigger = { - apply(interval) - } - - def create(interval: Long, unit: TimeUnit): ContinuousTrigger = { - ContinuousTrigger(convert(interval, unit)) - } -} From 8911578020f8a2428b12dd72cb0ed4b7d747d835 Mon Sep 17 00:00:00 2001 From: Steven Aerts Date: Tue, 8 Aug 2023 08:09:05 +0900 Subject: [PATCH 58/68] [SPARK-44132][SQL] Materialize `Stream` of join column names to avoid codegen failure ### What changes were proposed in this pull request? Materialize passed join columns as an `IndexedSeq` before passing it to the lower layers. ### Why are the changes needed? When nesting multiple full outer joins using column names which are a `Stream`, the code generator will generate faulty code resulting in a NPE or bad `UnsafeRow` access at runtime. See the 2 added test cases. ### Why are the changes needed? Otherwise the code will crash, see the 2 added test cases. Which show an NPE and a bad `UnsafeRow` access. ### Does this PR introduce _any_ user-facing change? No, only bug fix. ### How was this patch tested? A reproduction scenario was created and added to the code base. Closes #41712 from steven-aerts/SPARK-44132-fix. Authored-by: Steven Aerts Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../execution/joins/JoinCodegenSupport.scala | 2 +- .../org/apache/spark/sql/JoinSuite.scala | 20 +++++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 61c83829d2012..eda017937d918 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1092,7 +1092,7 @@ class Dataset[T] private[sql]( Join( joined.left, joined.right, - UsingJoin(JoinType(joinType), usingColumns), + UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), None, JoinHint.NONE) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala index a7d1edefcd611..6496f9a0006e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala @@ -79,7 +79,7 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec { setDefaultValue: Boolean): Seq[ExprCode] = { ctx.currentVars = null ctx.INPUT_ROW = row - plan.output.zipWithIndex.map { case (a, i) => + plan.output.toIndexedSeq.zipWithIndex.map { case (a, i) => val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) if (setDefaultValue) { // the variables are needed even there is no matched rows diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 7f358723eeb8f..14f1fb27906a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1709,4 +1709,24 @@ class JoinSuite extends QueryTest with 
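A minimal reproduction of the failure mode, adapted from the new test cases (assumes a local `SparkSession` bound to `spark` with its implicits imported):

```scala
import spark.implicits._

val dsA = Seq((1, "a")).toDF("id", "c1")
val dsB = Seq((2, "b")).toDF("id", "c2")
val dsC = Seq((3, "c")).toDF("id", "c3")

// Passing the using-columns as a lazy Stream previously produced faulty generated
// code (NPE or invalid UnsafeRow access) once the full outer joins were nested.
// With the columns materialized as an IndexedSeq this returns the expected rows:
// (1, a, null, null), (2, null, b, null), (3, null, null, c).
val joined = dsA
  .join(dsB, Stream("id"), "full_outer")
  .join(dsC, Stream("id"), "full_outer")

joined.show()
```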
SharedSparkSession with AdaptiveSparkPlan checkAnswer(sql(query), expected) } } + + test("SPARK-44132: FULL OUTER JOIN by streamed column name fails with NPE") { + val dsA = Seq((1, "a")).toDF("id", "c1") + val dsB = Seq((2, "b")).toDF("id", "c2") + val dsC = Seq((3, "c")).toDF("id", "c3") + val joined = dsA.join(dsB, Stream("id"), "full_outer").join(dsC, Stream("id"), "full_outer") + + val expected = Seq(Row(1, "a", null, null), Row(2, null, "b", null), Row(3, null, null, "c")) + + checkAnswer(joined, expected) + } + + test("SPARK-44132: FULL OUTER JOIN by streamed column name fails with invalid access") { + val ds = Seq((1, "a")).toDF("id", "c1") + val joined = ds.join(ds, Stream("id"), "full_outer").join(ds, Stream("id"), "full_outer") + + val expected = Seq(Row(1, "a", "a", "a")) + + checkAnswer(joined, expected) + } } From f47a2560e6e39ba8eac51a76290614b2fba4d65a Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 8 Aug 2023 08:10:09 +0900 Subject: [PATCH 59/68] [MINOR][UI] Increasing the number of significant digits for Fraction Cached of RDD ### What changes were proposed in this pull request? This PR is a typo improvement for increasing the number of significant digits for Fraction Cached of RDD that shows on the Storage Tab. ### Why are the changes needed? improves accuracy and precision ![image](https://github.com/apache/spark/assets/8326978/7106352c-b806-4953-8938-c4cba8ea1191) ### Does this PR introduce _any_ user-facing change? Yes, the Fraction Cached on Storage Page increases the fractional length from 0 to 2 ### How was this patch tested? locally verified Closes #42373 from yaooqinn/uiminor. Authored-by: Kent Yao Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/ui/storage/StoragePage.scala | 2 +- .../spark/ui/storage/StoragePageSuite.scala | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index c1708c320c5d4..726622673650d 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -98,7 +98,7 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends {rdd.storageLevel} {rdd.numCachedPartitions.toString} - {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} + {"%.2f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memoryUsed)} {Utils.bytesToString(rdd.diskUsed)} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 718c6856cb31f..d1e25bf8a2346 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -48,8 +48,8 @@ class StoragePageSuite extends SparkFunSuite { val rdd2 = new RDDStorageInfo(2, "rdd2", - 10, - 5, + 1000, + 56, StorageLevel.DISK_ONLY.description, 0L, 200L, @@ -58,8 +58,8 @@ class StoragePageSuite extends SparkFunSuite { val rdd3 = new RDDStorageInfo(3, "rdd3", - 10, - 10, + 1000, + 103, StorageLevel.MEMORY_AND_DISK_SER.description, 400L, 500L, @@ -94,19 +94,20 @@ class StoragePageSuite extends SparkFunSuite { assert((xmlNodes \\ "tr").size === 3) assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === - Seq("1", "rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) + Seq("1", 
"rdd1", "Memory Deserialized 1x Replicated", "10", "100.00%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd/?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === - Seq("2", "rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) + Seq("2", "rdd2", "Disk Serialized 1x Replicated", "56", "5.60%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd/?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === - Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) + Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "103", "10.30%", "400.0 B", + "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd/?id=3")) From 9368a0f0c1001fb6fd64799a2e744874b6cd27e4 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 8 Aug 2023 11:03:05 +0900 Subject: [PATCH 60/68] [SPARK-44694][PYTHON][CONNECT] Refactor active sessions and expose them as an API ### What changes were proposed in this pull request? This PR proposes to (mostly) refactor all the internal workarounds to get the active session correctly. There are few things to note: - _PySpark without Spark Connect does not already support the hierarchy of active sessions_. With pinned thread mode (enabled by default), PySpark does map each Python thread to JVM thread, but the thread creation happens within gateway server, that does not respect the thread hierarchy. Therefore, this PR follows the exactly same behaviour. - New thread will not have an active thread by default. - Other behaviours are same as PySpark without Connect, see also https://github.com/apache/spark/pull/42367 - Since I am here, I piggiyback few documentation changes. We missed document `SparkSession.readStream`, `SparkSession.streams`, `SparkSession.udtf`, `SparkSession.conf` and `SparkSession.version` in Spark Connect. - The changes here are mostly refactoring that reuses existing unittests while I expose two methods: - `SparkSession.getActiveSession` (only for Spark Connect) - `SparkSession.active` (for both in PySpark) ### Why are the changes needed? For Spark Connect users to be able to play with active and default sessions in Python. ### Does this PR introduce _any_ user-facing change? Yes, it adds new API: - `SparkSession.getActiveSession` (only for Spark Connect) - `SparkSession.active` (for both in PySpark) ### How was this patch tested? Existing unittests should cover all. Closes #42371 from HyukjinKwon/SPARK-44694. 
Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../reference/pyspark.sql/spark_session.rst | 1 + python/pyspark/errors/error_classes.py | 5 + python/pyspark/ml/connect/io_utils.py | 8 +- python/pyspark/ml/connect/tuning.py | 11 +- python/pyspark/ml/torch/distributor.py | 3 +- python/pyspark/ml/util.py | 13 --- python/pyspark/pandas/utils.py | 7 +- python/pyspark/sql/connect/session.py | 107 ++++++++++++------ python/pyspark/sql/connect/udf.py | 25 ++-- python/pyspark/sql/connect/udtf.py | 27 +++-- python/pyspark/sql/session.py | 65 +++++++++-- .../sql/tests/connect/test_connect_basic.py | 4 +- python/pyspark/sql/utils.py | 18 +++ 13 files changed, 197 insertions(+), 97 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst index c16ca4f162f5c..f25dbab5f6b9b 100644 --- a/python/docs/source/reference/pyspark.sql/spark_session.rst +++ b/python/docs/source/reference/pyspark.sql/spark_session.rst @@ -28,6 +28,7 @@ See also :class:`SparkSession`. .. autosummary:: :toctree: api/ + SparkSession.active SparkSession.builder.appName SparkSession.builder.config SparkSession.builder.enableHiveSupport diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index a534bc6deb41e..24885e94d3255 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -617,6 +617,11 @@ "Argument `` should be a WindowSpec, got ." ] }, + "NO_ACTIVE_OR_DEFAULT_SESSION" : { + "message" : [ + "No active or default Spark session found. Please create a new Spark session before running the code." + ] + }, "NO_ACTIVE_SESSION" : { "message" : [ "No active Spark session found. Please create a new Spark session before running the code." diff --git a/python/pyspark/ml/connect/io_utils.py b/python/pyspark/ml/connect/io_utils.py index 9a963086aaf45..a09a244862c58 100644 --- a/python/pyspark/ml/connect/io_utils.py +++ b/python/pyspark/ml/connect/io_utils.py @@ -23,7 +23,7 @@ from urllib.parse import urlparse from typing import Any, Dict, List from pyspark.ml.base import Params -from pyspark.ml.util import _get_active_session +from pyspark.sql import SparkSession from pyspark.sql.utils import is_remote @@ -34,7 +34,7 @@ def _copy_file_from_local_to_fs(local_path: str, dest_path: str) -> None: - session = _get_active_session(is_remote()) + session = SparkSession.active() if is_remote(): session.copyFromLocalToFs(local_path, dest_path) else: @@ -228,7 +228,7 @@ def save(self, path: str, *, overwrite: bool = False) -> None: .. versionadded:: 3.5.0 """ - session = _get_active_session(is_remote()) + session = SparkSession.active() path_exist = True try: session.read.format("binaryFile").load(path).head() @@ -256,7 +256,7 @@ def load(cls, path: str) -> "Params": .. versionadded:: 3.5.0 """ - session = _get_active_session(is_remote()) + session = SparkSession.active() tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_") try: diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py index 6d539933e1d69..c22c31e84e8de 100644 --- a/python/pyspark/ml/connect/tuning.py +++ b/python/pyspark/ml/connect/tuning.py @@ -178,11 +178,12 @@ def _parallelFitTasks( def get_single_task(index: int, param_map: Any) -> Callable[[], Tuple[int, float]]: def single_task() -> Tuple[int, float]: - # Active session is thread-local variable, in background thread the active session - # is not set, the following line sets it as the main thread active session. 
- active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr] - active_session._jsparkSession # type: ignore[union-attr] - ) + if not is_remote(): + # Active session is thread-local variable, in background thread the active session + # is not set, the following line sets it as the main thread active session. + active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr] + active_session._jsparkSession # type: ignore[union-attr] + ) model = estimator.fit(train, param_map) metric = evaluator.evaluate( diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py index 2056803d61cf4..a4e79b1dcc10b 100644 --- a/python/pyspark/ml/torch/distributor.py +++ b/python/pyspark/ml/torch/distributor.py @@ -49,7 +49,6 @@ LogStreamingServer, ) from pyspark.ml.dl_util import FunctionPickler -from pyspark.ml.util import _get_active_session def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]: @@ -165,7 +164,7 @@ def __init__( from pyspark.sql.utils import is_remote self.is_remote = is_remote() - self.spark = _get_active_session(self.is_remote) + self.spark = SparkSession.active() # indicate whether the server side is local mode self.is_spark_local_master = False diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 2c90ff3cb7b69..64676947017d0 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -747,16 +747,3 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return f(*args, **kwargs) return cast(FuncT, wrapped) - - -def _get_active_session(is_remote: bool) -> SparkSession: - if not is_remote: - spark = SparkSession.getActiveSession() - else: - import pyspark.sql.connect.session - - spark = pyspark.sql.connect.session._active_spark_session # type: ignore[assignment] - - if spark is None: - raise RuntimeError("An active SparkSession is required for the distributor.") - return spark diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index c66b3359e77d1..55b9a57ef6187 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -478,12 +478,7 @@ def is_testing() -> bool: def default_session() -> SparkSession: - if not is_remote(): - spark = SparkSession.getActiveSession() - else: - from pyspark.sql.connect.session import _active_spark_session - - spark = _active_spark_session # type: ignore[assignment] + spark = SparkSession.getActiveSession() if spark is None: spark = SparkSession.builder.appName("pandas-on-Spark").getOrCreate() diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 9bba0db05e43f..d75a30c561f93 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -18,6 +18,7 @@ check_dependencies(__name__) +import threading import os import warnings from collections.abc import Sized @@ -36,6 +37,7 @@ overload, Iterable, TYPE_CHECKING, + ClassVar, ) import numpy as np @@ -93,14 +95,13 @@ from pyspark.sql.connect.udtf import UDTFRegistration -# `_active_spark_session` stores the active spark connect session created by -# `SparkSession.builder.getOrCreate`. It is used by ML code. 
-# If sessions are created with `SparkSession.builder.create`, it stores -# The last created session -_active_spark_session = None - - class SparkSession: + # The active SparkSession for the current thread + _active_session: ClassVar[threading.local] = threading.local() + # Reference to the root SparkSession + _default_session: ClassVar[Optional["SparkSession"]] = None + _lock: ClassVar[RLock] = RLock() + class Builder: """Builder for :class:`SparkSession`.""" @@ -176,8 +177,6 @@ def enableHiveSupport(self) -> "SparkSession.Builder": ) def create(self) -> "SparkSession": - global _active_spark_session - has_channel_builder = self._channel_builder is not None has_spark_remote = "spark.remote" in self._options @@ -200,23 +199,26 @@ def create(self) -> "SparkSession": assert spark_remote is not None session = SparkSession(connection=spark_remote) - _active_spark_session = session + SparkSession._set_default_and_active_session(session) return session def getOrCreate(self) -> "SparkSession": - global _active_spark_session - if _active_spark_session is not None: - return _active_spark_session - _active_spark_session = self.create() - return _active_spark_session + with SparkSession._lock: + session = SparkSession.getActiveSession() + if session is None: + session = SparkSession._default_session + if session is None: + session = self.create() + return session _client: SparkConnectClient @classproperty def builder(cls) -> Builder: - """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" return cls.Builder() + builder.__doc__ = PySparkSession.builder.__doc__ + def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] = None): """ Creates a new SparkSession for the Spark Connect interface. @@ -236,6 +238,38 @@ def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] self._client = SparkConnectClient(connection=connection, user_id=userId) self._session_id = self._client._session_id + @classmethod + def _set_default_and_active_session(cls, session: "SparkSession") -> None: + """ + Set the (global) default :class:`SparkSession`, and (thread-local) + active :class:`SparkSession` when they are not set yet. + """ + with cls._lock: + if cls._default_session is None: + cls._default_session = session + if getattr(cls._active_session, "session", None) is None: + cls._active_session.session = session + + @classmethod + def getActiveSession(cls) -> Optional["SparkSession"]: + return getattr(cls._active_session, "session", None) + + getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__ + + @classmethod + def active(cls) -> "SparkSession": + session = cls.getActiveSession() + if session is None: + session = cls._default_session + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_OR_DEFAULT_SESSION", + message_parameters={}, + ) + return session + + active.__doc__ = PySparkSession.active.__doc__ + def table(self, tableName: str) -> DataFrame: return self.read.table(tableName) @@ -251,6 +285,8 @@ def read(self) -> "DataFrameReader": def readStream(self) -> "DataStreamReader": return DataStreamReader(self) + readStream.__doc__ = PySparkSession.readStream.__doc__ + def _inferSchemaFromList( self, data: Iterable[Any], names: Optional[List[str]] = None ) -> StructType: @@ -601,19 +637,20 @@ def stop(self) -> None: # specifically in Spark Connect the Spark Connect server is designed for # multi-tenancy - the remote client side cannot just stop the server and stop # other remote clients being used from other users. 
- global _active_spark_session - self.client.close() - _active_spark_session = None - - if "SPARK_LOCAL_REMOTE" in os.environ: - # When local mode is in use, follow the regular Spark session's - # behavior by terminating the Spark Connect server, - # meaning that you can stop local mode, and restart the Spark Connect - # client with a different remote address. - active_session = PySparkSession.getActiveSession() - if active_session is not None: - active_session.stop() - with SparkContext._lock: + with SparkSession._lock: + self.client.close() + if self is SparkSession._default_session: + SparkSession._default_session = None + if self is getattr(SparkSession._active_session, "session", None): + SparkSession._active_session.session = None + + if "SPARK_LOCAL_REMOTE" in os.environ: + # When local mode is in use, follow the regular Spark session's + # behavior by terminating the Spark Connect server, + # meaning that you can stop local mode, and restart the Spark Connect + # client with a different remote address. + if PySparkSession._activeSession is not None: + PySparkSession._activeSession.stop() del os.environ["SPARK_LOCAL_REMOTE"] del os.environ["SPARK_CONNECT_MODE_ENABLED"] if "SPARK_REMOTE" in os.environ: @@ -628,20 +665,18 @@ def is_stopped(self) -> bool: """ return self.client.is_closed - @classmethod - def getActiveSession(cls) -> Any: - raise PySparkNotImplementedError( - error_class="NOT_IMPLEMENTED", message_parameters={"feature": "getActiveSession()"} - ) - @property def conf(self) -> RuntimeConf: return RuntimeConf(self.client) + conf.__doc__ = PySparkSession.conf.__doc__ + @property def streams(self) -> "StreamingQueryManager": return StreamingQueryManager(self) + streams.__doc__ = PySparkSession.streams.__doc__ + def __getattr__(self, name: str) -> Any: if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession"]: raise PySparkAttributeError( @@ -675,6 +710,8 @@ def version(self) -> str: assert result is not None return result + version.__doc__ = PySparkSession.version.__doc__ + @property def client(self) -> "SparkConnectClient": return self._client diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 2d7e423d3d571..eb0541b936925 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -37,8 +37,7 @@ from pyspark.sql.connect.types import UnparsedDataType from pyspark.sql.types import DataType, StringType from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration -from pyspark.errors import PySparkTypeError - +from pyspark.errors import PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql.connect._typing import ( @@ -58,14 +57,20 @@ def _create_py_udf( from pyspark.sql.udf import _create_arrow_py_udf if useArrow is None: - from pyspark.sql.connect.session import _active_spark_session - - is_arrow_enabled = ( - False - if _active_spark_session is None - else _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") - == "true" - ) + is_arrow_enabled = False + try: + from pyspark.sql.connect.session import SparkSession + + session = SparkSession.active() + is_arrow_enabled = ( + str(session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled")).lower() + == "true" + ) + except PySparkRuntimeError as e: + if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION": + pass # Just uses the default if no session found. 
+ else: + raise e else: is_arrow_enabled = useArrow diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 5a95075a65537..c8495626292c5 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -68,13 +68,20 @@ def _create_py_udtf( if useArrow is not None: arrow_enabled = useArrow else: - from pyspark.sql.connect.session import _active_spark_session + from pyspark.sql.connect.session import SparkSession arrow_enabled = False - if _active_spark_session is not None: - value = _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") - if isinstance(value, str) and value.lower() == "true": - arrow_enabled = True + try: + session = SparkSession.active() + arrow_enabled = ( + str(session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")).lower() + == "true" + ) + except PySparkRuntimeError as e: + if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION": + pass # Just uses the default if no session found. + else: + raise e # Create a regular Python UDTF and check for invalid handler class. regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) @@ -160,17 +167,13 @@ def _build_common_inline_user_defined_table_function( ) def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + from pyspark.sql.connect.session import SparkSession from pyspark.sql.connect.dataframe import DataFrame - from pyspark.sql.connect.session import _active_spark_session - if _active_spark_session is None: - raise PySparkRuntimeError( - "An active SparkSession is required for " - "executing a Python user-defined table function." - ) + session = SparkSession.active() plan = self._build_common_inline_user_defined_table_function(*cols) - return DataFrame.withPlan(plan, _active_spark_session) + return DataFrame.withPlan(plan, session) def asNondeterministic(self) -> "UserDefinedTableFunction": self.deterministic = False diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index ede6318782e0a..9141051fdf830 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -64,8 +64,8 @@ _from_numpy_type, ) from pyspark.errors.exceptions.captured import install_exception_handler -from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str -from pyspark.errors import PySparkValueError, PySparkTypeError +from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod +from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql._typing import AtomicValue, RowLike, OptionalPrimitiveType @@ -500,7 +500,7 @@ def getOrCreate(self) -> "SparkSession": ).applyModifiableSettings(session._jsparkSession, self._options) return session - # SparkConnect-specific API + # Spark Connect-specific API def create(self) -> "SparkSession": """Creates a new SparkSession. Can only be used in the context of Spark Connect and will throw an exception otherwise. @@ -510,6 +510,10 @@ def create(self) -> "SparkSession": Returns ------- :class:`SparkSession` + + Notes + ----- + This method will update the default and/or active session if they are not set. 
""" opts = dict(self._options) if "SPARK_REMOTE" in os.environ or "spark.remote" in opts: @@ -546,7 +550,11 @@ def create(self) -> "SparkSession": # to Python 3.9.6 (https://github.com/python/cpython/pull/28838) @classproperty def builder(cls) -> Builder: - """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" + """Creates a :class:`Builder` for constructing a :class:`SparkSession`. + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + """ return cls.Builder() _instantiatedSession: ClassVar[Optional["SparkSession"]] = None @@ -632,12 +640,16 @@ def newSession(self) -> "SparkSession": return self.__class__(self._sc, self._jsparkSession.newSession()) @classmethod + @try_remote_session_classmethod def getActiveSession(cls) -> Optional["SparkSession"]: """ Returns the active :class:`SparkSession` for the current thread, returned by the builder .. versionadded:: 3.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`SparkSession` @@ -667,6 +679,30 @@ def getActiveSession(cls) -> Optional["SparkSession"]: else: return None + @classmethod + @try_remote_session_classmethod + def active(cls) -> "SparkSession": + """ + Returns the active or default :class:`SparkSession` for the current thread, returned by + the builder. + + .. versionadded:: 3.5.0 + + Returns + ------- + :class:`SparkSession` + Spark session if an active or default session exists for the current thread. + """ + session = cls.getActiveSession() + if session is None: + session = cls._instantiatedSession + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_OR_DEFAULT_SESSION", + message_parameters={}, + ) + return session + @property def sparkContext(self) -> SparkContext: """ @@ -698,6 +734,9 @@ def version(self) -> str: .. versionadded:: 2.0.0 + .. versionchanged:: 3.4.0 + Supports Spark Connect. + Returns ------- str @@ -719,6 +758,9 @@ def conf(self) -> RuntimeConfig: .. versionadded:: 2.0.0 + .. versionchanged:: 3.4.0 + Supports Spark Connect. + Returns ------- :class:`pyspark.sql.conf.RuntimeConfig` @@ -726,7 +768,7 @@ def conf(self) -> RuntimeConfig: Examples -------- >>> spark.conf - + Set a runtime configuration for the session @@ -805,6 +847,9 @@ def udtf(self) -> "UDTFRegistration": .. versionadded:: 3.5.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`UDTFRegistration` @@ -1639,6 +1684,9 @@ def readStream(self) -> DataStreamReader: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -1650,7 +1698,7 @@ def readStream(self) -> DataStreamReader: Examples -------- >>> spark.readStream - + The example below uses Rate source that generates rows continuously. After that, we operate a modulo by 3, and then write the stream out to the console. @@ -1672,6 +1720,9 @@ def streams(self) -> "StreamingQueryManager": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. 
@@ -1683,7 +1734,7 @@ def streams(self) -> "StreamingQueryManager": Examples -------- >>> spark.streams - + Get the list of active streaming queries diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 065f1585a9f06..0687fc9f31331 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3043,9 +3043,6 @@ def test_unsupported_functions(self): def test_unsupported_session_functions(self): # SPARK-41934: Disable unsupported functions. - with self.assertRaises(NotImplementedError): - RemoteSparkSession.getActiveSession() - with self.assertRaises(NotImplementedError): RemoteSparkSession.builder.enableHiveSupport() @@ -3331,6 +3328,7 @@ def test_error_stack_trace(self): spark.stop() def test_can_create_multiple_sessions_to_different_remotes(self): + self.spark.stop() self.assertIsNotNone(self.spark._client) # Creates a new remote session. other = PySparkSession.builder.remote("sc://other.remote:114/").create() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8b520ed653f8c..d4f56fe822f3e 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import inspect import functools import os from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type @@ -258,6 +259,23 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return cast(FuncT, wrapped) +def try_remote_session_classmethod(f: FuncT) -> FuncT: + """Mark API supported from Spark Connect.""" + + @functools.wraps(f) + def wrapped(*args: Any, **kwargs: Any) -> Any: + + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + from pyspark.sql.connect.session import SparkSession # type: ignore[misc] + + assert inspect.isclass(args[0]) + return getattr(SparkSession, f.__name__)(*args[1:], **kwargs) + else: + return f(*args, **kwargs) + + return cast(FuncT, wrapped) + + def pyspark_column_op( func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None ) -> Union["SeriesOrIndex", None]: From 630b1777904f15c7ac05c3cd61c0006cd692bc93 Mon Sep 17 00:00:00 2001 From: Siying Dong Date: Tue, 8 Aug 2023 11:11:56 +0900 Subject: [PATCH 61/68] [SPARK-44683][SS] Logging level isn't passed to RocksDB state store provider correctly ### What changes were proposed in this pull request? The logging level is passed into RocksDB in a correct way. ### Why are the changes needed? We pass log4j's log level to RocksDB so that RocksDB debug log can go to log4j. However, we pass in log level after we create the logger. However, the way it is set isn't effective. This has two impacts: (1) setting DEBUG level don't make RocksDB generate DEBUG level logs; (2) setting WARN or ERROR level does prevent INFO level logging, but RocksDB still makes JNI calls to Scala, which is an unnecessary overhead. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually change the log level and observe the log lines in unit tests. Closes #42354 from siying/rocks_log_level. 
Authored-by: Siying Dong Signed-off-by: Jungtaek Lim --- .../apache/spark/sql/execution/streaming/state/RocksDB.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index d4366fe732be4..a2868df941178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -611,8 +611,11 @@ class RocksDB( if (log.isWarnEnabled) dbLogLevel = InfoLogLevel.WARN_LEVEL if (log.isInfoEnabled) dbLogLevel = InfoLogLevel.INFO_LEVEL if (log.isDebugEnabled) dbLogLevel = InfoLogLevel.DEBUG_LEVEL - dbOptions.setLogger(dbLogger) + dbLogger.setInfoLogLevel(dbLogLevel) + // The log level set in dbLogger is effective and the one to dbOptions isn't applied to + // customized logger. We still set it as it might show up in RocksDB config file or logging. dbOptions.setInfoLogLevel(dbLogLevel) + dbOptions.setLogger(dbLogger) logInfo(s"Set RocksDB native logging level to $dbLogLevel") dbLogger } From 7493c5764f9644878babacccd4f688fe13ef84aa Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 8 Aug 2023 04:15:07 +0200 Subject: [PATCH 62/68] [SPARK-43429][CONNECT] Add Default & Active SparkSession for Scala Client ### What changes were proposed in this pull request? This adds the `default` and `active` session variables to `SparkSession`: - `default` session is global value. It is typically the first session created through `getOrCreate`. It can be changed through `set` or `clear`. If the session is closed and it is the `default` session we clear the `default` session. - `active` session is a thread local value. It is typically the first session created in this thread or it inherits is value from its parent thread. It can be changed through `set` or `clear`, please note that these methods operate thread locally, so they won't change the parent or children. If the session is closed and it is the `active` session for the current thread then we clear the active value (only for the current thread!). ### Why are the changes needed? To increase compatibility with the existing SparkSession API in `sql/core`. ### Does this PR introduce _any_ user-facing change? Yes. It adds a couple methods that were missing from the Scala Client. ### How was this patch tested? Added tests to `SparkSessionSuite`. Closes #42367 from hvanhovell/SPARK-43429. 
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/SparkSession.scala | 100 ++++++++++-- .../apache/spark/sql/SparkSessionSuite.scala | 144 ++++++++++++++++-- .../CheckConnectJvmClientCompatibility.scala | 2 - 3 files changed, 225 insertions(+), 21 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 355d7edadc788..7367ed153f7db 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.Closeable import java.net.URI import java.util.concurrent.TimeUnit._ -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag @@ -730,6 +730,23 @@ object SparkSession extends Logging { override def load(c: Configuration): SparkSession = create(c) }) + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[SparkSession] + + /** + * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when + * they are not set yet. + */ + private def setDefaultAndActiveSession(session: SparkSession): Unit = { + defaultSession.compareAndSet(null, session) + if (getActiveSession.isEmpty) { + setActiveSession(session) + } + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ @@ -742,8 +759,17 @@ object SparkSession extends Logging { */ private[sql] def onSessionClose(session: SparkSession): Unit = { sessions.invalidate(session.client.configuration) + defaultSession.compareAndSet(session, null) + if (getActiveSession.contains(session)) { + clearActiveSession() + } } + /** + * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. + * + * @since 3.4.0 + */ def builder(): Builder = new Builder() private[sql] lazy val cleaner = { @@ -799,10 +825,15 @@ object SparkSession extends Logging { * * This will always return a newly created session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def create(): SparkSession = { - tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(SparkSession.this.create(builder.configuration)) + setDefaultAndActiveSession(session) + session } /** @@ -811,30 +842,79 @@ object SparkSession extends Logging { * If a session exist with the same configuration that is returned instead of creating a new * session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def getOrCreate(): SparkSession = { - tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(sessions.get(builder.configuration)) + setDefaultAndActiveSession(session) + session } } - def getActiveSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getActiveSession is not supported") + /** + * Returns the default SparkSession. 
+ * + * @since 3.5.0 + */ + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get()) + + /** + * Sets the default SparkSession. + * + * @since 3.5.0 + */ + def setDefaultSession(session: SparkSession): Unit = { + defaultSession.set(session) } - def getDefaultSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getDefaultSession is not supported") + /** + * Clears the default SparkSession. + * + * @since 3.5.0 + */ + def clearDefaultSession(): Unit = { + defaultSession.set(null) } + /** + * Returns the active SparkSession for the current thread. + * + * @since 3.5.0 + */ + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get()) + + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * an isolated SparkSession. + * + * @since 3.5.0 + */ def setActiveSession(session: SparkSession): Unit = { - throw new UnsupportedOperationException("setActiveSession is not supported") + activeThreadSession.set(session) } + /** + * Clears the active SparkSession for current thread. + * + * @since 3.5.0 + */ def clearActiveSession(): Unit = { - throw new UnsupportedOperationException("clearActiveSession is not supported") + activeThreadSession.remove() } + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 3.5.0 + */ def active: SparkSession = { - throw new UnsupportedOperationException("active is not supported") + getActiveSession + .orElse(getDefaultSession) + .getOrElse(throw new IllegalStateException("No active or default Spark session found")) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 97fb46bf48af4..f06744399f833 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -16,6 +16,10 @@ */ package org.apache.spark.sql +import java.util.concurrent.{Executors, Phaser} + +import scala.util.control.NonFatal + import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} import org.apache.spark.sql.connect.client.util.ConnectFunSuite @@ -24,6 +28,10 @@ import org.apache.spark.sql.connect.client.util.ConnectFunSuite * Tests for non-dataframe related SparkSession operations. 
*/ class SparkSessionSuite extends ConnectFunSuite { + private val connectionString1: String = "sc://test.it:17845" + private val connectionString2: String = "sc://test.me:14099" + private val connectionString3: String = "sc://doit:16845" + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") @@ -32,16 +40,15 @@ class SparkSessionSuite extends ConnectFunSuite { } test("remote") { - val session = SparkSession.builder().remote("sc://test.me:14099").getOrCreate() + val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) session.close() } test("getOrCreate") { - val connectionString = "sc://test.it:17865" - val session1 = SparkSession.builder().remote(connectionString).getOrCreate() - val session2 = SparkSession.builder().remote(connectionString).getOrCreate() + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + val session2 = SparkSession.builder().remote(connectionString1).getOrCreate() try { assert(session1 eq session2) } finally { @@ -51,9 +58,8 @@ class SparkSessionSuite extends ConnectFunSuite { } test("create") { - val connectionString = "sc://test.it:17845" - val session1 = SparkSession.builder().remote(connectionString).create() - val session2 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() try { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) @@ -64,8 +70,7 @@ class SparkSessionSuite extends ConnectFunSuite { } test("newSession") { - val connectionString = "sc://doit:16845" - val session1 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString3).create() val session2 = session1.newSession() try { assert(session1 ne session2) @@ -92,5 +97,126 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } + session.close() + } + + test("Default/Active session") { + // Make sure we start with a clean slate. + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + intercept[IllegalStateException](SparkSession.active) + + // Create a session + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + assert(SparkSession.active == session1) + + // Create another session... 
+ val session2 = SparkSession.builder().remote(connectionString2).create() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Clear sessions + SparkSession.clearDefaultSession() + assert(SparkSession.getDefaultSession.isEmpty) + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + + // Flip sessions + SparkSession.setActiveSession(session1) + SparkSession.setDefaultSession(session2) + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.contains(session1)) + + // Close session1 + session1.close() + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.isEmpty) + + // Close session2 + session2.close() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + + test("active session in multiple threads") { + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + val phaser = new Phaser(2) + val executor = Executors.newFixedThreadPool(2) + def execute(block: Phaser => Unit): java.util.concurrent.Future[Boolean] = { + executor.submit[Boolean] { () => + try { + block(phaser) + true + } catch { + case NonFatal(e) => + phaser.forceTermination() + throw e + } + } + } + + try { + val script1 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + session1.close() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + val script2 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + val internalSession = SparkSession.builder().remote(connectionString3).getOrCreate() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + internalSession.close() + assert(SparkSession.getActiveSession.isEmpty) + } + assert(script1.get()) + assert(script2.get()) + assert(SparkSession.getActiveSession.contains(session2)) + session2.close() + assert(SparkSession.getActiveSession.isEmpty) + } finally { + executor.shutdown() + 
} } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 6e577e0f21257..2bf9c41fb2cbd 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), // SparkSession - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.setDefaultSession"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"), From aa1261dc129618d27a1bdc743a5fdd54219f7c01 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 7 Aug 2023 19:16:38 -0700 Subject: [PATCH 63/68] [SPARK-44641][SQL] Incorrect result in certain scenarios when SPJ is not triggered ### What changes were proposed in this pull request? This PR makes sure we use unique partition values when calculating the final partitions in `BatchScanExec`, to make sure no duplicated partitions are generated. ### Why are the changes needed? When `spark.sql.sources.v2.bucketing.pushPartValues.enabled` and `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` are enabled, and SPJ is not triggered, currently Spark will generate incorrect/duplicated results. This is because with both configs enabled, Spark will delay the partition grouping until the time it calculates the final partitions used by the input RDD. To calculate the partitions, it uses partition values from the `KeyGroupedPartitioning` to find out the right ordering for the partitions. However, since grouping is not done when the partition values is computed, there could be duplicated partition values. This means the result could contain duplicated partitions too. ### Does this PR introduce _any_ user-facing change? No, this is a bug fix. ### How was this patch tested? Added a new test case for this scenario. Closes #42324 from sunchao/SPARK-44641. Authored-by: Chao Sun Signed-off-by: Chao Sun --- .../plans/physical/partitioning.scala | 9 ++- .../datasources/v2/BatchScanExec.scala | 9 ++- .../KeyGroupedPartitioningSuite.scala | 56 +++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index bd8ba54ddd736..456005768bd42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -313,7 +313,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in * ascending order, after evaluated by the transforms in `expressions`, for each input partition. 
* In addition, its length must be the same as the number of input partitions (and thus is a 1-1 - * mapping), and each row in `partitionValues` must be unique. + * mapping). The `partitionValues` may contain duplicated partition values. * * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValues` is * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows @@ -355,6 +355,13 @@ case class KeyGroupedPartitioning( override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = KeyGroupedShuffleSpec(this, distribution) + + lazy val uniquePartitionValues: Seq[InternalRow] = { + partitionValues + .map(InternalRowComparableWrapper(_, expressions)) + .distinct + .map(_.row) + } } object KeyGroupedPartitioning { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 4b53819739262..eba3c71f871e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -190,10 +190,17 @@ case class BatchScanExec( Seq.fill(numSplits)(Seq.empty)) } } else { + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. val partitionMapping = groupedPartitions.map { case (row, parts) => InternalRowComparableWrapper(row, p.expressions) -> parts }.toMap - finalPartitions = p.partitionValues.map { partValue => + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. 
+ finalPartitions = p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 880c30ba9f98d..8461f528277c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1039,4 +1039,60 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-44641: duplicated records when SPJ is not triggered") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, items_schema, items_partitions) + sql(s""" + INSERT INTO testcat.ns.$items VALUES + (1, 'aa', 40.0, cast('2020-01-01' as timestamp)), + (1, 'aa', 41.0, cast('2020-01-15' as timestamp)), + (2, 'bb', 10.0, cast('2020-01-01' as timestamp)), + (2, 'bb', 10.5, cast('2020-01-01' as timestamp)), + (3, 'cc', 15.5, cast('2020-02-01' as timestamp))""") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"""INSERT INTO testcat.ns.$purchases VALUES + (1, 42.0, cast('2020-01-01' as timestamp)), + (1, 44.0, cast('2020-01-15' as timestamp)), + (1, 45.0, cast('2020-01-15' as timestamp)), + (2, 11.0, cast('2020-01-01' as timestamp)), + (3, 19.5, cast('2020-02-01' as timestamp))""") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClusteredEnabled => + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClusteredEnabled.toString) { + + // join keys are not the same as the partition keys, therefore SPJ is not triggered. + val df = sql( + s""" + SELECT id, name, i.price as purchase_price, p.item_id, p.price as sale_price + FROM testcat.ns.$items i JOIN testcat.ns.$purchases p + ON i.arrive_time = p.time ORDER BY id, purchase_price, p.item_id, sale_price + """) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, "shuffle should exist when SPJ is not used") + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } } From 6dadd188f3652816c291919a2413f73c13bb1b47 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 8 Aug 2023 11:04:53 +0800 Subject: [PATCH 64/68] [SPARK-44554][INFRA] Make Python linter related checks pass of branch-3.3/3.4 daily testing ### What changes were proposed in this pull request? The daily testing of `branch-3.3/3.4` uses the same yml file as the master now and the upgrade to `MyPy` in https://github.com/apache/spark/pull/41690 resulted in Python linter check failure of `branch-3.3/3.4`, - branch-3.3: https://github.com/apache/spark/actions/runs/5677524469/job/15386025539 - branch-3.4: https://github.com/apache/spark/actions/runs/5678626664/job/15389273919 image So this pr do the following change for workaround: 1. 
Install different Python linter dependencies for `branch-3.3/3.4`, the dependency list comes from the corresponding branch to ensure compatibility with the version 2. Skip `Install dependencies for Python code generation check` and `Python code generation check` for `branch-3.3/3.4` due to they do not use `Buf remote plugins` and `Buf remote generation` is no longer supported. Meanwhile, the protobuf files in the branch generally do not change, so we can skip this check. ### Why are the changes needed? Make Python linter related checks pass of branch-3.3/3.4 daily testing ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Manually checked branch-3.4, the newly added condition should be ok Closes #42167 from LuciferYang/SPARK-44554. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: yangjie01 --- .github/workflows/build_and_test.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index cd68c0904d9a4..b4559dea42bb9 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -657,7 +657,22 @@ jobs: - name: Spark connect jvm client mima check if: inputs.branch != 'branch-3.3' run: ./dev/connect-jvm-client-mima-check + - name: Install Python linter dependencies for branch-3.3 + if: inputs.branch == 'branch-3.3' + run: | + # SPARK-44554: Copy from https://github.com/apache/spark/blob/073d0b60d31bf68ebacdc005f59b928a5902670f/.github/workflows/build_and_test.yml#L501-L508 + # Should delete this section after SPARK 3.3 EOL. + python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==21.12b0' + python3.9 -m pip install 'pandas-stubs==1.2.0.53' + - name: Install Python linter dependencies for branch-3.4 + if: inputs.branch == 'branch-3.4' + run: | + # SPARK-44554: Copy from https://github.com/apache/spark/blob/a05c27e85829fe742c1828507a1fd180cdc84b54/.github/workflows/build_and_test.yml#L571-L578 + # Should delete this section after SPARK 3.4 EOL. + python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' + python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.48.1' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - name: Install Python linter dependencies + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: | # TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes. # See also https://github.com/sphinx-doc/sphinx/issues/7551. 
@@ -668,6 +683,7 @@ jobs: - name: Python linter run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python - name: Install dependencies for Python code generation check + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: | # See more in "Installation" https://docs.buf.build/installation#tarball curl -LO https://github.com/bufbuild/buf/releases/download/v1.24.0/buf-Linux-x86_64.tar.gz @@ -676,6 +692,7 @@ jobs: rm buf-Linux-x86_64.tar.gz python3.9 -m pip install 'protobuf==3.20.3' 'mypy-protobuf==3.3.0' - name: Python code generation check + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi - name: Install JavaScript linter dependencies run: | From 25053d98186489d9f2061c9b815a5a33f7e309c4 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 8 Aug 2023 11:06:21 +0800 Subject: [PATCH 65/68] [SPARK-44689][CONNECT] Make the exception handling of function `SparkConnectPlanner#unpackScalarScalaUDF` more universal ### What changes were proposed in this pull request? This PR changes the exception handling in the `unpackScalarScalaUD` function in `SparkConnectPlanner` from determining the exception type based on a fixed nesting level to using Guava `Throwables` to get the root cause and then determining the type of the root cause. This makes it compatible with differences between different Java versions. ### Why are the changes needed? The following failure occurred when testing `UDFClassLoadingE2ESuite` in Java 17 daily test: https://github.com/apache/spark/actions/runs/5766913899/job/15635782831 ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session *** FAILED *** (101 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:57) ... [info] - update class loader after stubbing: same session *** FAILED *** (52 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:73) ... 
``` After analysis, it was found that there are differences in the exception stack generated on the server side between Java 8 and Java 17: - Java 8 ``` java.io.IOException: unexpected exception type at java.io.ObjectStreamClass.throwMiscException(ObjectStreamClass.java:1750) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1280) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2222) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461) at org.apache.spark.util.SparkSerDeUtils.deserialize(SparkSerDeUtils.scala:50) at org.apache.spark.util.SparkSerDeUtils.deserialize$(SparkSerDeUtils.scala:41) at org.apache.spark.util.Utils$.deserialize(Utils.scala:95) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.unpackScalarScalaUDF(SparkConnectPlanner.scala:1516) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.org$apache$spark$sql$connect$planner$SparkConnectPlanner$$unpackUdf(SparkConnectPlanner.scala:1507) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformScalarScalaFunction(SparkConnectPlanner.scala:1544) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterScalarScalaUDF(SparkConnectPlanner.scala:2565) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterUserDefinedFunction(SparkConnectPlanner.scala:2492) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.process(SparkConnectPlanner.scala:2363) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.handleCommand(ExecuteThreadRunner.scala:202) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1(ExecuteThreadRunner.scala:158) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1$adapted(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:184) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:184) at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withContextClassLoader$1(SessionHolder.scala:171) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:179) at org.apache.spark.sql.connect.service.SessionHolder.withContextClassLoader(SessionHolder.scala:170) at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:183) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.executeInternal(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.org$apache$spark$sql$connect$execution$ExecuteThreadRunner$$execute(ExecuteThreadRunner.scala:84) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner$ExecutionThread.run(ExecuteThreadRunner.scala:227) Caused by: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at 
java.lang.Class.getDeclaredMethod(Class.java:2130) at java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:224) at java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:221) at java.security.AccessController.doPrivileged(Native Method) at java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:221) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1274) ... 31 more ``` - Java 17 ``` java.lang.RuntimeException: Exception in SerializedLambda.readResolve at java.base/java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:288) at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77) at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.base/java.lang.reflect.Method.invoke(Method.java:568) at java.base/java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1190) at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2266) at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1733) at java.base/java.io.ObjectInputStream$FieldValues.(ObjectInputStream.java:2606) at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2457) at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2257) at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1733) at java.base/java.io.ObjectInputStream.readObject(ObjectInputStream.java:509) at java.base/java.io.ObjectInputStream.readObject(ObjectInputStream.java:467) at org.apache.spark.util.SparkSerDeUtils.deserialize(SparkSerDeUtils.scala:50) at org.apache.spark.util.SparkSerDeUtils.deserialize$(SparkSerDeUtils.scala:41) at org.apache.spark.util.Utils$.deserialize(Utils.scala:95) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.unpackScalarScalaUDF(SparkConnectPlanner.scala:1517) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.org$apache$spark$sql$connect$planner$SparkConnectPlanner$$unpackUdf(SparkConnectPlanner.scala:1507) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformScalarScalaFunction(SparkConnectPlanner.scala:1552) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterScalarScalaUDF(SparkConnectPlanner.scala:2573) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterUserDefinedFunction(SparkConnectPlanner.scala:2500) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.process(SparkConnectPlanner.scala:2371) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.handleCommand(ExecuteThreadRunner.scala:202) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1(ExecuteThreadRunner.scala:158) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1$adapted(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:184) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at 
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:184) at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withContextClassLoader$1(SessionHolder.scala:171) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:179) at org.apache.spark.sql.connect.service.SessionHolder.withContextClassLoader(SessionHolder.scala:170) at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:183) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.executeInternal(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.org$apache$spark$sql$connect$execution$ExecuteThreadRunner$$execute(ExecuteThreadRunner.scala:84) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner$ExecutionThread.run(ExecuteThreadRunner.scala:227) Caused by: java.security.PrivilegedActionException: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at java.base/java.security.AccessController.doPrivileged(AccessController.java:573) at java.base/java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:269) ... 36 more Caused by: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at java.base/java.lang.Class.getDeclaredMethod(Class.java:2675) at java.base/java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:272) at java.base/java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:269) at java.base/java.security.AccessController.doPrivileged(AccessController.java:569) ... 37 more ``` While their root exceptions are both `NoSuchMethodException`, the levels of nesting are different. We can add an exception check branch to make it compatible with Java 17, for example: ```scala case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => throw new ClassNotFoundException(... ${e.getCause} ...) case e: RuntimeException if e.getCause != null && e.getCause.getCause.isInstanceOf[NoSuchMethodException] => throw new ClassNotFoundException(... ${e.getCause.getCause} ...) ``` But if future Java versions change the nested levels of exceptions again, this will necessitate another modification of this part of the code. Therefore, this PR has been revised to fetch the root cause of the exception and conduct a type check on the root cause to make it as universal as possible. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass Git Hub Actions - Manually check with Java 17 ``` java -version openjdk version "17.0.8" 2023-07-18 LTS OpenJDK Runtime Environment Zulu17.44+15-CA (build 17.0.8+7-LTS) OpenJDK 64-Bit Server VM Zulu17.44+15-CA (build 17.0.8+7-LTS, mixed mode, sharing) ``` run ``` build/sbt clean "connect-client-jvm/testOnly *UDFClassLoadingE2ESuite" -Phive ``` Before ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session *** FAILED *** (60 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:57) ... 
[info] - update class loader after stubbing: same session *** FAILED *** (15 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:73) ... [info] Run completed in 9 seconds, 565 milliseconds. [info] Total number of tests run: 2 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 0, failed 2, canceled 0, ignored 0, pending 0 [info] *** 2 TESTS FAILED *** [error] Failed tests: [error] org.apache.spark.sql.connect.client.UDFClassLoadingE2ESuite [error] (connect-client-jvm / Test / testOnly) sbt.TestsFailedException: Tests unsuccessful ``` After ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session (116 milliseconds) [info] - update class loader after stubbing: same session (41 milliseconds) [info] Run completed in 9 seconds, 781 milliseconds. [info] Total number of tests run: 2 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` Closes #42360 from LuciferYang/unpackScalarScalaUDF-exception-java17. Authored-by: yangjie01 Signed-off-by: yangjie01 --- .../connect/planner/SparkConnectPlanner.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7136476b515f9..f70a17e580a3e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.connect.planner -import java.io.IOException - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Try +import com.google.common.base.Throwables import com.google.common.collect.{Lists, Maps} import com.google.protobuf.{Any => ProtoAny, ByteString} import io.grpc.{Context, Status, StatusRuntimeException} @@ -1518,11 +1517,15 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { logDebug(s"Unpack using class loader: ${Utils.getContextOrSparkClassLoader}") Utils.deserialize[T](fun.getPayload.toByteArray, Utils.getContextOrSparkClassLoader) } catch { - case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => - throw new ClassNotFoundException( - s"Failed to load class correctly due to ${e.getCause}. " + - "Make sure the artifact where the class is defined is installed by calling" + - " session.addArtifact.") + case t: Throwable => + Throwables.getRootCause(t) match { + case nsm: NoSuchMethodException => + throw new ClassNotFoundException( + s"Failed to load class correctly due to $nsm. " + + "Make sure the artifact where the class is defined is installed by calling" + + " session.addArtifact.") + case _ => throw t + } } } From 590b77f76284ad03ad8b3b6d30b23983c66513fc Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Tue, 8 Aug 2023 11:09:58 +0800 Subject: [PATCH 66/68] [SPARK-44005][PYTHON] Improve error messages for regular Python UDTFs that return non-tuple values ### What changes were proposed in this pull request? This PR improves error messages for regular Python UDTFs when the result rows are not one of tuple, list and dict. 
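As a minimal sketch (hypothetical `OkUDTF`/`BadUDTF` names, assuming an active Spark session and Arrow optimization disabled), the check concerns the shape of each value produced by `eval`:

```python
from pyspark.sql.functions import lit, udtf

@udtf(returnType="a: int")
class OkUDTF:
    def eval(self, a: int):
        # Each output row must be a tuple, list, or dict; a 1-tuple is a valid row.
        yield (a,)

@udtf(returnType="a: int")
class BadUDTF:
    def eval(self, a: int):
        # A bare scalar is not a row; it now fails with a clearer error.
        yield a

OkUDTF(lit(1)).collect()     # [Row(a=1)]
# BadUDTF(lit(1)).collect()  # raises PythonException: [UDTF_INVALID_OUTPUT_ROW_TYPE] ...
```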
Note this is supported when arrow optimization is enabled. ### Why are the changes needed? To make Python UDTFs more user friendly. ### Does this PR introduce _any_ user-facing change? Yes. ``` class TestUDTF: def eval(self, a: int): yield a ``` Before this PR, this will fail with this error `Unexpected tuple 1 with StructType` After this PR, this will have a more user-friendly error: `[UDTF_INVALID_OUTPUT_ROW_TYPE] The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got 'int'. Please make sure that the output rows are of the correct type.` ### How was this patch tested? Existing UTs. Closes #42353 from allisonwang-db/spark-44005-non-tuple-return-val. Authored-by: allisonwang-db Signed-off-by: Ruifeng Zheng --- python/pyspark/errors/error_classes.py | 5 +++++ python/pyspark/sql/tests/test_udtf.py | 26 +++++++++++--------------- python/pyspark/worker.py | 12 +++++++++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 24885e94d3255..bc32afeb87a9f 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -743,6 +743,11 @@ "User defined table function encountered an error in the '' method: " ] }, + "UDTF_INVALID_OUTPUT_ROW_TYPE" : { + "message" : [ + "The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got ''. Please make sure that the output rows are of the correct type." + ] + }, "UDTF_RETURN_NOT_ITERABLE" : { "message" : [ "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got ''. Please make sure that the UDTF returns one of these types." diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index b2f473996bcb6..300067716e9de 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -163,24 +163,21 @@ def eval(self, a: int, b: int): self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)]) def test_udtf_eval_returning_non_tuple(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a - func = udtf(TestUDTF, returnType="a: int") - # TODO(SPARK-44005): improve this error message - with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): - func(lit(1)).collect() + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() - def test_udtf_eval_returning_non_generator(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): return (a,) - func = udtf(TestUDTF, returnType="a: int") - # TODO(SPARK-44005): improve this error message - with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): - func(lit(1)).collect() + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() def test_udtf_with_invalid_return_value(self): @udtf(returnType="x: int") @@ -1852,21 +1849,20 @@ def eval(self): self.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", old_value) def test_udtf_eval_returning_non_tuple(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a - func = udtf(TestUDTF, returnType="a: int") # When arrow is enabled, it can handle non-tuple return value. 
- self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) - def test_udtf_eval_returning_non_generator(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): - return (a,) + return [a] - func = udtf(TestUDTF, returnType="a: int") - self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) def test_numeric_output_type_casting(self): class TestUDTF: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b32e20e3b0418..6f27400387e72 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -648,9 +648,8 @@ def wrap_udtf(f, return_type): return_type_size = len(return_type) def verify_and_convert_result(result): - # TODO(SPARK-44005): support returning non-tuple values - if result is not None and hasattr(result, "__len__"): - if len(result) != return_type_size: + if result is not None: + if hasattr(result, "__len__") and len(result) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ "actual": str(len(result)), }, ) + + if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")): + raise PySparkRuntimeError( + error_class="UDTF_INVALID_OUTPUT_ROW_TYPE", + message_parameters={"type": type(result).__name__}, + ) + return toInternal(result) # Evaluate the function and return a tuple back to the executor. From b4b91212b1d4ce8f47f9e1abeb26b06122c01f13 Mon Sep 17 00:00:00 2001 From: Shuyou Dong Date: Tue, 8 Aug 2023 12:17:53 +0900 Subject: [PATCH 67/68] [SPARK-44703][CORE] Log eventLog rewrite duration when compact old event log files ### What changes were proposed in this pull request? Log the eventLog rewrite duration when compacting old event log files. ### Why are the changes needed? When `spark.eventLog.rolling.enabled` is enabled and the number of eventLog files exceeds the value of `spark.history.fs.eventLog.rolling.maxFilesToRetain`, the HistoryServer compacts the old event log files into one compact file. Currently the rewrite duration is not logged in the rewrite method; this metric is useful for understanding how long compaction takes, so this change adds a log message to that method. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manual test. Closes #42378 from shuyouZZ/SPARK-44703.
Authored-by: Shuyou Dong Signed-off-by: Jungtaek Lim --- .../apache/spark/deploy/history/EventLogFileCompactor.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala index 8558f765175fc..27040e83533ff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala @@ -149,6 +149,7 @@ class EventLogFileCompactor( val logWriter = new CompactedEventLogFileWriter(lastIndexEventLogPath, "dummy", None, lastIndexEventLogPath.getParent.toUri, sparkConf, hadoopConf) + val startTime = System.currentTimeMillis() logWriter.start() eventLogFiles.foreach { file => EventFilter.applyFilterToFile(fs, filters, file.getPath, @@ -158,6 +159,8 @@ class EventLogFileCompactor( ) } logWriter.stop() + val duration = System.currentTimeMillis() - startTime + logInfo(s"Finished rewriting eventLog files to ${logWriter.logPath} took $duration ms.") logWriter.logPath } From d2b60ff51fabdb38899e649aa2e700112534d79c Mon Sep 17 00:00:00 2001 From: itholic Date: Tue, 8 Aug 2023 16:16:11 +0900 Subject: [PATCH 68/68] [SPARK-43567][PS] Support `use_na_sentinel` for `factorize` ### What changes were proposed in this pull request? This PR proposes to support `use_na_sentinel` for `factorize`. ### Why are the changes needed? To match the behavior with [pandas 2](https://pandas.pydata.org/docs/dev/whatsnew/v2.0.0.html) ### Does this PR introduce _any_ user-facing change? Yes, the `na_sentinel` is removed in favor of `use_na_sentinel`. ### How was this patch tested? Enabling the existing tests. Closes #42270 from itholic/pandas_use_na_sentinel. Authored-by: itholic Signed-off-by: Hyukjin Kwon --- .../migration_guide/pyspark_upgrade.rst | 1 + python/pyspark/pandas/base.py | 39 +++++++------------ .../connect/series/test_parity_compute.py | 4 ++ .../pandas/tests/indexes/test_category.py | 8 +--- .../pandas/tests/series/test_compute.py | 20 ++++------ 5 files changed, 29 insertions(+), 43 deletions(-) diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index 7a691ee264571..d26f1cbbe0dc4 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -29,6 +29,7 @@ Upgrading from PySpark 3.5 to 4.0 * In Spark 4.0, ``Series.append`` has been removed from pandas API on Spark, use ``ps.concat`` instead. * In Spark 4.0, ``DataFrame.mad`` has been removed from pandas API on Spark. * In Spark 4.0, ``Series.mad`` has been removed from pandas API on Spark. +* In Spark 4.0, ``na_sentinel`` parameter from ``Index.factorize`` and `Series.factorize`` has been removed from pandas API on Spark, use ``use_na_sentinel`` instead. 
Upgrading from PySpark 3.3 to 3.4 diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index 2de260e6e9351..0685af769872a 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -1614,7 +1614,7 @@ def take(self: IndexOpsLike, indices: Sequence[int]) -> IndexOpsLike: return cast(IndexOpsLike, self._psdf.iloc[indices].index) def factorize( - self: IndexOpsLike, sort: bool = True, na_sentinel: Optional[int] = -1 + self: IndexOpsLike, sort: bool = True, use_na_sentinel: bool = True ) -> Tuple[IndexOpsLike, pd.Index]: """ Encode the object as an enumerated type or categorical variable. @@ -1625,11 +1625,11 @@ def factorize( Parameters ---------- sort : bool, default True - na_sentinel : int or None, default -1 - Value to mark "not found". If None, will not drop the NaN - from the uniques of the values. - - .. deprecated:: 3.4.0 + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values, effectively assigning them + a distinct category. If False, NaN values will be encoded as non-negative integers, + treating them as unique categories in the encoding process and retaining them in the + set of unique categories in the data. Returns ------- @@ -1658,7 +1658,7 @@ def factorize( >>> uniques Index(['a', 'b', 'c'], dtype='object') - >>> codes, uniques = psser.factorize(na_sentinel=None) + >>> codes, uniques = psser.factorize(use_na_sentinel=False) >>> codes 0 1 1 3 @@ -1669,17 +1669,6 @@ def factorize( >>> uniques Index(['a', 'b', 'c', None], dtype='object') - >>> codes, uniques = psser.factorize(na_sentinel=-2) - >>> codes - 0 1 - 1 -2 - 2 0 - 3 2 - 4 1 - dtype: int32 - >>> uniques - Index(['a', 'b', 'c'], dtype='object') - For Index: >>> psidx = ps.Index(['b', None, 'a', 'c', 'b']) @@ -1691,8 +1680,8 @@ def factorize( """ from pyspark.pandas.series import first_series - assert (na_sentinel is None) or isinstance(na_sentinel, int) assert sort is True + use_na_sentinel = -1 if use_na_sentinel else False # type: ignore[assignment] warnings.warn( "Argument `na_sentinel` will be removed in 4.0.0.", @@ -1716,7 +1705,7 @@ def factorize( scol = map_scol[self.spark.column] codes, uniques = self._with_new_scol( scol.alias(self._internal.data_spark_column_names[0]) - ).factorize(na_sentinel=na_sentinel) + ).factorize(use_na_sentinel=use_na_sentinel) return codes, uniques.astype(self.dtype) uniq_sdf = self._internal.spark_frame.select(self.spark.column).distinct() @@ -1743,13 +1732,13 @@ def factorize( # Constructs `unique_to_code` mapping non-na unique to code unique_to_code = {} - if na_sentinel is not None: - na_sentinel_code = na_sentinel + if use_na_sentinel: + na_sentinel_code = use_na_sentinel code = 0 for unique in uniques_list: if pd.isna(unique): - if na_sentinel is None: - na_sentinel_code = code + if not use_na_sentinel: + na_sentinel_code = code # type: ignore[assignment] else: unique_to_code[unique] = code code += 1 @@ -1767,7 +1756,7 @@ def factorize( codes = self._with_new_scol(new_scol.alias(self._internal.data_spark_column_names[0])) - if na_sentinel is not None: + if use_na_sentinel: # Drops the NaN from the uniques of the values uniques_list = [x for x in uniques_list if not pd.isna(x)] diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py index 8876fcb139885..31916f12b4e7f 100644 --- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +++ 
b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py @@ -24,6 +24,10 @@ class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): pass + @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.") + def test_factorize(self): + super().test_factorize() + if __name__ == "__main__": from pyspark.pandas.tests.connect.series.test_parity_compute import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py index ffffae828c437..6aa92b7e1e390 100644 --- a/python/pyspark/pandas/tests/indexes/test_category.py +++ b/python/pyspark/pandas/tests/indexes/test_category.py @@ -210,10 +210,6 @@ def test_astype(self): self.assert_eq(pscidx.astype(str), pcidx.astype(str)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43567): Enable CategoricalIndexTests.test_factorize for pandas 2.0.0.", - ) def test_factorize(self): pidx = pd.CategoricalIndex([1, 2, 3, None]) psidx = ps.from_pandas(pidx) @@ -224,8 +220,8 @@ def test_factorize(self): self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) - pcodes, puniques = pidx.factorize(na_sentinel=-2) - kcodes, kuniques = psidx.factorize(na_sentinel=-2) + pcodes, puniques = pidx.factorize(use_na_sentinel=-2) + kcodes, kuniques = psidx.factorize(use_na_sentinel=-2) self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) diff --git a/python/pyspark/pandas/tests/series/test_compute.py b/python/pyspark/pandas/tests/series/test_compute.py index 155649179e6ef..784bf29e1a25b 100644 --- a/python/pyspark/pandas/tests/series/test_compute.py +++ b/python/pyspark/pandas/tests/series/test_compute.py @@ -407,10 +407,6 @@ def test_abs(self): self.assert_eq(abs(psser), abs(pser)) self.assert_eq(np.abs(psser), np.abs(pser)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43550): Enable SeriesTests.test_factorize for pandas 2.0.0.", - ) def test_factorize(self): pser = pd.Series(["a", "b", "a", "b"]) psser = ps.from_pandas(pser) @@ -492,27 +488,27 @@ def test_factorize(self): pser = pd.Series(["a", "b", "a", np.nan, None]) psser = ps.from_pandas(pser) - pcodes, puniques = pser.factorize(sort=True, na_sentinel=-2) - kcodes, kuniques = psser.factorize(na_sentinel=-2) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=-2) + kcodes, kuniques = psser.factorize(use_na_sentinel=-2) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) - pcodes, puniques = pser.factorize(sort=True, na_sentinel=2) - kcodes, kuniques = psser.factorize(na_sentinel=2) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=2) + kcodes, kuniques = psser.factorize(use_na_sentinel=2) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) if not pd_below_1_1_2: - pcodes, puniques = pser.factorize(sort=True, na_sentinel=None) - kcodes, kuniques = psser.factorize(na_sentinel=None) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=None) + kcodes, kuniques = psser.factorize(use_na_sentinel=None) self.assert_eq(pcodes.tolist(), kcodes.to_list()) # puniques is Index(['a', 'b', nan], dtype='object') self.assert_eq(ps.Index(["a", "b", None]), kuniques) psser = ps.Series([1, 2, np.nan, 4, 5]) # Arrow takes np.nan as null psser.loc[3] = np.nan # Spark takes np.nan as NaN - kcodes, kuniques = psser.factorize(na_sentinel=None) - 
pcodes, puniques = psser._to_pandas().factorize(sort=True, na_sentinel=None) + kcodes, kuniques = psser.factorize(use_na_sentinel=None) + pcodes, puniques = psser._to_pandas().factorize(sort=True, use_na_sentinel=None) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques)
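For reference, a small usage sketch of the reworked `use_na_sentinel` behavior (illustrative only; the expected values mirror the updated `factorize` docstring, and a running pandas-on-Spark session is assumed):

```python
import pyspark.pandas as ps

psser = ps.Series(["b", None, "a", "c", "b"])

# Default (use_na_sentinel=True): missing values get the sentinel code -1
# and are dropped from the returned uniques.
codes, uniques = psser.factorize()
# codes   -> [1, -1, 0, 2, 1]
# uniques -> Index(['a', 'b', 'c'], dtype='object')

# use_na_sentinel=False: missing values become their own category
# and are kept in the returned uniques.
codes, uniques = psser.factorize(use_na_sentinel=False)
# codes   -> [1, 3, 0, 2, 1]
# uniques -> Index(['a', 'b', 'c', None], dtype='object')
```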