diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 47c1be1ba863b..b4559dea42bb9 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -241,7 +241,10 @@ jobs: restore-keys: | ${{ matrix.java }}-${{ matrix.hadoop }}-coursier- - name: Free up disk space - run: ./dev/free_disk_space + run: | + if [ -f ./dev/free_disk_space ]; then + ./dev/free_disk_space + fi - name: Install Java ${{ matrix.java }} uses: actions/setup-java@v3 with: @@ -350,9 +353,11 @@ jobs: - >- pyspark-errors - >- - pyspark-sql, pyspark-mllib, pyspark-resource, pyspark-testing + pyspark-sql, pyspark-resource, pyspark-testing - >- - pyspark-core, pyspark-streaming, pyspark-ml + pyspark-core, pyspark-streaming + - >- + pyspark-mllib, pyspark-ml, pyspark-ml-connect - >- pyspark-pandas - >- @@ -410,6 +415,16 @@ jobs: key: pyspark-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | pyspark-coursier- + - name: Free up disk space + shell: 'script -q -e -c "bash {0}"' + run: | + if [[ "$MODULES_TO_TEST" != *"pyspark-ml"* ]]; then + # uninstall libraries dedicated for ML testing + python3.9 -m pip uninstall -y torch torchvision torcheval torchtnt tensorboard mlflow + fi + if [ -f ./dev/free_disk_space_container ]; then + ./dev/free_disk_space_container + fi - name: Install Java ${{ matrix.java }} uses: actions/setup-java@v3 with: @@ -424,6 +439,7 @@ jobs: run: | curl -s https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh > miniconda.sh bash miniconda.sh -b -p $HOME/miniconda + rm miniconda.sh # Run the tests. - name: Run tests env: ${{ fromJSON(inputs.envs) }} @@ -507,6 +523,11 @@ jobs: key: sparkr-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | sparkr-coursier- + - name: Free up disk space + run: | + if [ -f ./dev/free_disk_space_container ]; then + ./dev/free_disk_space_container + fi - name: Install Java ${{ inputs.java }} uses: actions/setup-java@v3 with: @@ -615,6 +636,11 @@ jobs: key: docs-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | docs-maven- + - name: Free up disk space + run: | + if [ -f ./dev/free_disk_space_container ]; then + ./dev/free_disk_space_container + fi - name: Install Java 8 uses: actions/setup-java@v3 with: @@ -631,7 +657,22 @@ jobs: - name: Spark connect jvm client mima check if: inputs.branch != 'branch-3.3' run: ./dev/connect-jvm-client-mima-check + - name: Install Python linter dependencies for branch-3.3 + if: inputs.branch == 'branch-3.3' + run: | + # SPARK-44554: Copy from https://github.com/apache/spark/blob/073d0b60d31bf68ebacdc005f59b928a5902670f/.github/workflows/build_and_test.yml#L501-L508 + # Should delete this section after SPARK 3.3 EOL. + python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==21.12b0' + python3.9 -m pip install 'pandas-stubs==1.2.0.53' + - name: Install Python linter dependencies for branch-3.4 + if: inputs.branch == 'branch-3.4' + run: | + # SPARK-44554: Copy from https://github.com/apache/spark/blob/a05c27e85829fe742c1828507a1fd180cdc84b54/.github/workflows/build_and_test.yml#L571-L578 + # Should delete this section after SPARK 3.4 EOL. 
+ python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' + python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.48.1' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - name: Install Python linter dependencies + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: | # TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes. # See also https://github.com/sphinx-doc/sphinx/issues/7551. @@ -642,13 +683,16 @@ jobs: - name: Python linter run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python - name: Install dependencies for Python code generation check + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: | # See more in "Installation" https://docs.buf.build/installation#tarball curl -LO https://github.com/bufbuild/buf/releases/download/v1.24.0/buf-Linux-x86_64.tar.gz mkdir -p $HOME/buf tar -xvzf buf-Linux-x86_64.tar.gz -C $HOME/buf --strip-components 1 + rm buf-Linux-x86_64.tar.gz python3.9 -m pip install 'protobuf==3.20.3' 'mypy-protobuf==3.3.0' - name: Python code generation check + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi - name: Install JavaScript linter dependencies run: | @@ -1027,6 +1071,7 @@ jobs: # TODO(SPARK-44495): Resume to use the latest minikube for k8s-integration-tests. curl -LO https://storage.googleapis.com/minikube/releases/v1.30.1/minikube-linux-amd64 sudo install minikube-linux-amd64 /usr/local/bin/minikube + rm minikube-linux-amd64 # Github Action limit cpu:2, memory: 6947MB, limit to 2U6G for better resource statistic minikube start --cpus 2 --memory 6144 - name: Print K8S pods and nodes info diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index 48a4d6b5ff990..618ab69ba5998 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -57,11 +57,11 @@ jobs: - hive2.3 modules: - >- - core,repl,launcher,common#unsafe,common#kvstore,common#network-common,common#network-shuffle,common#sketch + core,launcher,common#unsafe,common#kvstore,common#network-common,common#network-shuffle,common#sketch - >- graphx,streaming,mllib-local,mllib,hadoop-cloud - >- - sql#hive-thriftserver + repl,sql#hive-thriftserver - >- connector#kafka-0-10,connector#kafka-0-10-sql,connector#kafka-0-10-token-provider,connector#spark-ganglia-lgpl,connector#protobuf,connector#avro - >- @@ -187,9 +187,9 @@ jobs: ./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pspark-ganglia-lgpl -Djava.version=${JAVA_VERSION/-ea} -Dtest.exclude.tags="$EXCLUDED_TAGS" test -fae elif [[ "$MODULES_TO_TEST" == "connect" ]]; then ./build/mvn $MAVEN_CLI_OPTS -Djava.version=${JAVA_VERSION/-ea} -pl connector/connect/client/jvm,connector/connect/common,connector/connect/server test -fae - elif [[ "$MODULES_TO_TEST" == "sql#hive-thriftserver" ]]; then + elif [[ "$MODULES_TO_TEST" == *"sql#hive-thriftserver"* ]]; then # To avoid a compilation loop, for the `sql/hive-thriftserver` module, run `clean install` instead - ./build/mvn $MAVEN_CLI_OPTS -pl sql/hive-thriftserver -Phive -Phive-thriftserver -Djava.version=${JAVA_VERSION/-ea} clean install -fae + ./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive 
-Phive-thriftserver -Phadoop-cloud -Pspark-ganglia-lgpl -Djava.version=${JAVA_VERSION/-ea} clean install -fae else ./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Pspark-ganglia-lgpl -Phadoop-cloud -Djava.version=${JAVA_VERSION/-ea} test -fae fi diff --git a/.gitignore b/.gitignore index 11141961bf805..064b502175b79 100644 --- a/.gitignore +++ b/.gitignore @@ -117,6 +117,6 @@ spark-warehouse/ node_modules # For Antlr -sql/catalyst/gen/ -sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens -sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/ +sql/api/gen/ +sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens +sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/ diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 063505228340e..0ea1eed35e463 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -809,12 +809,12 @@ "subClass" : { "BOTH_POSITIONAL_AND_NAMED" : { "message" : [ - "A positional argument and named argument both referred to the same parameter." + "A positional argument and named argument both referred to the same parameter. Please remove the named argument referring to this parameter." ] }, "DOUBLE_NAMED_ARGUMENT_REFERENCE" : { "message" : [ - "More than one named argument referred to the same parameter." + "More than one named argument referred to the same parameter. Please assign a value only once." ] } }, @@ -831,6 +831,11 @@ "Not found an encoder of the type to Spark SQL internal representation. Consider to change the input type to one of supported at '/sql-ref-datatypes.html'." ] }, + "ERROR_READING_AVRO_UNKNOWN_FINGERPRINT" : { + "message" : [ + "Error reading avro data -- encountered an unknown fingerprint: , not sure what schema to use. This could happen if you registered additional schemas after starting your spark context." + ] + }, "EVENT_TIME_IS_NOT_ON_TIMESTAMP_TYPE" : { "message" : [ "The event time has the invalid type , but expected \"TIMESTAMP\"." @@ -864,6 +869,11 @@ ], "sqlState" : "22018" }, + "FAILED_REGISTER_CLASS_WITH_KRYO" : { + "message" : [ + "Failed to register classes with Kryo." + ] + }, "FAILED_RENAME_PATH" : { "message" : [ "Failed to rename to as destination already exists." @@ -1564,6 +1574,12 @@ ], "sqlState" : "22032" }, + "INVALID_KRYO_SERIALIZER_BUFFER_SIZE" : { + "message" : [ + "The value of the config \"\" must be less than 2048 MiB, but got MiB." + ], + "sqlState" : "F0000" + }, "INVALID_LAMBDA_FUNCTION_CALL" : { "message" : [ "Invalid lambda function call." @@ -2006,6 +2022,11 @@ "The join condition has the invalid type , expected \"BOOLEAN\"." ] }, + "KRYO_BUFFER_OVERFLOW" : { + "message" : [ + "Kryo serialization failed: . To avoid this, increase \"\" value." + ] + }, "LOAD_DATA_PATH_NOT_EXISTS" : { "message" : [ "LOAD DATA input path does not exist: ." @@ -2043,6 +2064,11 @@ "Parsing JSON arrays as structs is forbidden." ] }, + "CANNOT_PARSE_STRING_AS_DATATYPE" : { + "message" : [ + "Cannot parse the value of the field as target spark data type from the input type ." + ] + }, "WITHOUT_SUGGESTION" : { "message" : [ "" @@ -2446,7 +2472,7 @@ }, "REQUIRED_PARAMETER_NOT_FOUND" : { "message" : [ - "Cannot invoke function because the parameter named is required, but the function call did not supply a value. 
Please update the function call to supply an argument value (either positionally or by name) and retry the query again." + "Cannot invoke function because the parameter named is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally at index or by name) and retry the query again." ], "sqlState" : "4274K" }, @@ -2471,6 +2497,12 @@ ], "sqlState" : "42883" }, + "RULE_ID_NOT_FOUND" : { + "message" : [ + "Not found an id for the rule name \"\". Please modify RuleIdCollection.scala if you are adding a new rule." + ], + "sqlState" : "22023" + }, "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION" : { "message" : [ "The correlated scalar subquery '' is neither present in GROUP BY, nor in an aggregate function. Add it to GROUP BY using ordinal position or wrap it in `first()` (or `first_value`) if you don't care which value you get." @@ -2647,7 +2679,7 @@ }, "UNEXPECTED_POSITIONAL_ARGUMENT" : { "message" : [ - "Cannot invoke function because it contains positional argument(s) following named argument(s); please rearrange them so the positional arguments come first and then retry the query again." + "Cannot invoke function because it contains positional argument(s) following the named argument assigned to ; please rearrange them so the positional arguments come first and then retry the query again." ], "sqlState" : "4274K" }, @@ -5312,11 +5344,6 @@ "Exception when registering StreamingQueryListener." ] }, - "_LEGACY_ERROR_TEMP_2133" : { - "message" : [ - "Cannot parse field name , field value , [] as target spark data type []." - ] - }, "_LEGACY_ERROR_TEMP_2134" : { "message" : [ "Cannot parse field value for pattern as target spark data type []." @@ -5489,11 +5516,6 @@ "." ] }, - "_LEGACY_ERROR_TEMP_2175" : { - "message" : [ - "Rule id not found for . Please modify RuleIdCollection.scala if you are adding a new rule." - ] - }, "_LEGACY_ERROR_TEMP_2176" : { "message" : [ "Cannot create array with elements of data due to exceeding the limit elements for ArrayData. " diff --git a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java deleted file mode 100644 index 27ffe67d9909c..0000000000000 --- a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.streaming; - -import java.util.concurrent.TimeUnit; - -import scala.concurrent.duration.Duration; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.AvailableNowTrigger$; -import org.apache.spark.sql.execution.streaming.ContinuousTrigger; -import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; -import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger; - -/** - * Policy used to indicate how often results should be produced by a [[StreamingQuery]]. - * - * @since 3.5.0 - */ -@Evolving -public class Trigger { - // This is a copy of the same class in sql/core/.../streaming/Trigger.java - - /** - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is 0, the query will run as fast as possible. - * - * @since 3.5.0 - */ - public static Trigger ProcessingTime(long intervalMs) { - return ProcessingTimeTrigger.create(intervalMs, TimeUnit.MILLISECONDS); - } - - /** - * (Java-friendly) - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is 0, the query will run as fast as possible. - * - * {{{ - * import java.util.concurrent.TimeUnit - * df.writeStream().trigger(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { - return ProcessingTimeTrigger.create(interval, timeUnit); - } - - /** - * (Scala-friendly) - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `duration` is 0, the query will run as fast as possible. - * - * {{{ - * import scala.concurrent.duration._ - * df.writeStream.trigger(Trigger.ProcessingTime(10.seconds)) - * }}} - * @since 3.5.0 - */ - public static Trigger ProcessingTime(Duration interval) { - return ProcessingTimeTrigger.apply(interval); - } - - /** - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is effectively 0, the query will run as fast as possible. - * - * {{{ - * df.writeStream.trigger(Trigger.ProcessingTime("10 seconds")) - * }}} - * @since 3.5.0 - */ - public static Trigger ProcessingTime(String interval) { - return ProcessingTimeTrigger.apply(interval); - } - - /** - * A trigger that processes all available data in a single batch then terminates the query. - * - * @since 3.5.0 - * @deprecated This is deprecated as of Spark 3.4.0. Use {@link #AvailableNow()} to leverage - * better guarantee of processing, fine-grained scale of batches, and better gradual - * processing of watermark advancement including no-data batch. - * See the NOTES in {@link #AvailableNow()} for details. - */ - @Deprecated - public static Trigger Once() { - return OneTimeTrigger$.MODULE$; - } - - /** - * A trigger that processes all available data at the start of the query in one or multiple - * batches, then terminates the query. - * - * Users are encouraged to set the source options to control the size of the batch as similar as - * controlling the size of the batch in {@link #ProcessingTime(long)} trigger. - * - * NOTES: - * - This trigger provides a strong guarantee of processing: regardless of how many batches were - * left over in previous run, it ensures all available data at the time of execution gets - * processed before termination. All uncommitted batches will be processed first. 
- * - Watermark gets advanced per each batch, and no-data batch gets executed before termination - * if the last batch advances the watermark. This helps to maintain smaller and predictable - * state size and smaller latency on the output of stateful operators. - * - * @since 3.5.0 - */ - public static Trigger AvailableNow() { - return AvailableNowTrigger$.MODULE$; - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * @since 3.5.0 - */ - public static Trigger Continuous(long intervalMs) { - return ContinuousTrigger.apply(intervalMs); - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * {{{ - * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(Trigger.Continuous(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - public static Trigger Continuous(long interval, TimeUnit timeUnit) { - return ContinuousTrigger.create(interval, timeUnit); - } - - /** - * (Scala-friendly) - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * {{{ - * import scala.concurrent.duration._ - * df.writeStream.trigger(Trigger.Continuous(10.seconds)) - * }}} - * @since 3.5.0 - */ - public static Trigger Continuous(Duration interval) { - return ContinuousTrigger.apply(interval); - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * {{{ - * df.writeStream.trigger(Trigger.Continuous("10 seconds")) - * }}} - * @since 3.5.0 - */ - public static Trigger Continuous(String interval) { - return ContinuousTrigger.apply(interval); - } -} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 0f7b376955c96..8a7dce3987a44 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2832,7 +2832,7 @@ class Dataset[T] private[sql] ( /** * Returns an iterator that contains all rows in this Dataset. * - * The returned iterator implements [[AutoCloseable]]. For memory management it is better to + * The returned iterator implements [[AutoCloseable]]. For resource management it is better to * close it once you are done. If you don't close it, it and the underlying data will be cleaned * up once the iterator is garbage collected. 
* @@ -2840,7 +2840,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def toLocalIterator(): java.util.Iterator[T] = { - collectResult().destructiveIterator + collectResult().destructiveIterator.asJava } /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala index 3f2f7ec96d4f5..74f0133803137 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder => RowEncoderFactory} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.types.StructType /** * Methods for creating an [[Encoder]]. @@ -168,6 +169,13 @@ object Encoders { */ def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass) + /** + * Creates a [[Row]] encoder for schema `schema`. + * + * @since 3.5.0 + */ + def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema) + private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = { ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]] } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 59f3f3526ab2f..7367ed153f7db 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.Closeable import java.net.URI import java.util.concurrent.TimeUnit._ -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag @@ -252,10 +252,8 @@ class SparkSession private[sql] ( .setSql(sqlText) .addAllPosArgs(args.map(toLiteralProto).toIterable.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseSeq = client.execute(plan.build()).asScala.toSeq - - // sequence is a lazy stream, force materialize it to make sure it is consumed. - responseSeq.foreach(_ => ()) + // .toBuffer forces that the iterator is consumed and closed + val responseSeq = client.execute(plan.build()).toBuffer.toSeq val response = responseSeq .find(_.hasSqlCommandResult) @@ -311,10 +309,8 @@ class SparkSession private[sql] ( .setSql(sqlText) .putAllArgs(args.asScala.mapValues(toLiteralProto).toMap.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseSeq = client.execute(plan.build()).asScala.toSeq - - // sequence is a lazy stream, force materialize it to make sure it is consumed. 
- responseSeq.foreach(_ => ()) + // .toBuffer forces that the iterator is consumed and closed + val responseSeq = client.execute(plan.build()).toBuffer.toSeq val response = responseSeq .find(_.hasSqlCommandResult) @@ -548,15 +544,14 @@ class SparkSession private[sql] ( f(builder) builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement()) val plan = proto.Plan.newBuilder().setRoot(builder).build() - client.execute(plan).asScala.foreach(_ => ()) + // .toBuffer forces that the iterator is consumed and closed + client.execute(plan).toBuffer } private[sql] def execute(command: proto.Command): Seq[ExecutePlanResponse] = { val plan = proto.Plan.newBuilder().setCommand(command).build() - val seq = client.execute(plan).asScala.toSeq - // sequence is a lazy stream, force materialize it to make sure it is consumed. - seq.foreach(_ => ()) - seq + // .toBuffer forces that the iterator is consumed and closed + client.execute(plan).toBuffer.toSeq } private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = { @@ -735,6 +730,23 @@ object SparkSession extends Logging { override def load(c: Configuration): SparkSession = create(c) }) + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[SparkSession] + + /** + * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when + * they are not set yet. + */ + private def setDefaultAndActiveSession(session: SparkSession): Unit = { + defaultSession.compareAndSet(null, session) + if (getActiveSession.isEmpty) { + setActiveSession(session) + } + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ @@ -747,8 +759,17 @@ object SparkSession extends Logging { */ private[sql] def onSessionClose(session: SparkSession): Unit = { sessions.invalidate(session.client.configuration) + defaultSession.compareAndSet(session, null) + if (getActiveSession.contains(session)) { + clearActiveSession() + } } + /** + * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. + * + * @since 3.4.0 + */ def builder(): Builder = new Builder() private[sql] lazy val cleaner = { @@ -804,10 +825,15 @@ object SparkSession extends Logging { * * This will always return a newly created session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def create(): SparkSession = { - tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(SparkSession.this.create(builder.configuration)) + setDefaultAndActiveSession(session) + session } /** @@ -816,30 +842,79 @@ object SparkSession extends Logging { * If a session exist with the same configuration that is returned instead of creating a new * session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def getOrCreate(): SparkSession = { - tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(sessions.get(builder.configuration)) + setDefaultAndActiveSession(session) + session } } - def getActiveSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getActiveSession is not supported") + /** + * Returns the default SparkSession. 
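+   * Returns `None` if no default session is currently set.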
+ * + * @since 3.5.0 + */ + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get()) + + /** + * Sets the default SparkSession. + * + * @since 3.5.0 + */ + def setDefaultSession(session: SparkSession): Unit = { + defaultSession.set(session) } - def getDefaultSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getDefaultSession is not supported") + /** + * Clears the default SparkSession. + * + * @since 3.5.0 + */ + def clearDefaultSession(): Unit = { + defaultSession.set(null) } + /** + * Returns the active SparkSession for the current thread. + * + * @since 3.5.0 + */ + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get()) + + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * an isolated SparkSession. + * + * @since 3.5.0 + */ def setActiveSession(session: SparkSession): Unit = { - throw new UnsupportedOperationException("setActiveSession is not supported") + activeThreadSession.set(session) } + /** + * Clears the active SparkSession for current thread. + * + * @since 3.5.0 + */ def clearActiveSession(): Unit = { - throw new UnsupportedOperationException("clearActiveSession is not supported") + activeThreadSession.remove() } + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 3.5.0 + */ def active: SparkSession = { - throw new UnsupportedOperationException("active is not supported") + getActiveSession + .orElse(getDefaultSession) + .getOrElse(throw new IllegalStateException("No active or default Spark session found")) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala new file mode 100644 index 0000000000000..891e50ed6e7bd --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client + +private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self => + def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable { + override def next() = self.next() + + override def hasNext() = self.hasNext + + override def close() = self.close() + } +} + +private[sql] object CloseableIterator { + + /** + * Wrap iterator to get CloseeableIterator, if it wasn't closeable already. 
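+   * If the input iterator is not closeable, the returned wrapper's close() is a no-op.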
+ */ + def apply[T](iterator: Iterator[T]): CloseableIterator[T] = iterator match { + case closeable: CloseableIterator[T] => closeable + case _ => + new CloseableIterator[T] { + override def next(): T = iterator.next() + + override def hasNext(): Boolean = iterator.hasNext + + override def close() = { /* empty */ } + } + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index bb20901eade17..73ff01e223f29 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.connect.client +import scala.collection.JavaConverters._ + import io.grpc.ManagedChannel import org.apache.spark.connect.proto._ @@ -27,15 +29,17 @@ private[client] class CustomSparkConnectBlockingStub( private val stub = SparkConnectServiceGrpc.newBlockingStub(channel) private val retryHandler = new GrpcRetryHandler(retryPolicy) - def executePlan(request: ExecutePlanRequest): java.util.Iterator[ExecutePlanResponse] = { + def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = { GrpcExceptionConverter.convert { GrpcExceptionConverter.convertIterator[ExecutePlanResponse]( - retryHandler.RetryIterator(request, stub.executePlan)) + retryHandler.RetryIterator[ExecutePlanRequest, ExecutePlanResponse]( + request, + r => CloseableIterator(stub.executePlan(r).asScala))) } } def executePlanReattachable( - request: ExecutePlanRequest): java.util.Iterator[ExecutePlanResponse] = { + request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = { GrpcExceptionConverter.convert { GrpcExceptionConverter.convertIterator[ExecutePlanResponse]( // Don't use retryHandler - own retry handling is inside. diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 008b3c3dd5c71..d412d9b577064 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -20,7 +20,8 @@ import java.util.UUID import scala.util.control.NonFatal -import io.grpc.ManagedChannel +import io.grpc.{ManagedChannel, StatusRuntimeException} +import io.grpc.protobuf.StatusProto import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto @@ -38,21 +39,18 @@ import org.apache.spark.internal.Logging * Initial iterator is the result of an ExecutePlan on the request, but it can be reattached with * ReattachExecute request. ReattachExecute request is provided the responseId of last returned * ExecutePlanResponse on the iterator to return a new iterator from server that continues after - * that. + * that. If the initial ExecutePlan did not even reach the server, and hence reattach fails with + * INVALID_HANDLE.OPERATION_NOT_FOUND, we attempt to retry ExecutePlan. * * In reattachable execute the server does buffer some responses in case the client needs to * backtrack. 
To let server release this buffer sooner, this iterator asynchronously sends * ReleaseExecute RPCs that instruct the server to release responses that it already processed. - * - * Note: If the initial ExecutePlan did not even reach the server and execution didn't start, the - * ReattachExecute can still fail with INVALID_HANDLE.OPERATION_NOT_FOUND, failing the whole - * operation. */ class ExecutePlanResponseReattachableIterator( request: proto.ExecutePlanRequest, channel: ManagedChannel, retryPolicy: GrpcRetryHandler.RetryPolicy) - extends java.util.Iterator[proto.ExecutePlanResponse] + extends CloseableIterator[proto.ExecutePlanResponse] with Logging { val operationId = if (request.hasOperationId) { @@ -92,8 +90,8 @@ class ExecutePlanResponseReattachableIterator( // Initial iterator comes from ExecutePlan request. // Note: This is not retried, because no error would ever be thrown here, and GRPC will only - // throw error on first iterator.hasNext() or iterator.next() - private var iterator: java.util.Iterator[proto.ExecutePlanResponse] = + // throw error on first iter.hasNext() or iter.next() + private var iter: java.util.Iterator[proto.ExecutePlanResponse] = rawBlockingStub.executePlan(initialRequest) override def next(): proto.ExecutePlanResponse = synchronized { @@ -102,28 +100,33 @@ class ExecutePlanResponseReattachableIterator( throw new java.util.NoSuchElementException() } - // Get next response, possibly triggering reattach in case of stream error. - var firstTry = true - val ret = retry { - if (firstTry) { - // on first try, we use the existing iterator. - firstTry = false - } else { - // on retry, the iterator is borked, so we need a new one - iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + try { + // Get next response, possibly triggering reattach in case of stream error. + var firstTry = true + val ret = retry { + if (firstTry) { + // on first try, we use the existing iter. + firstTry = false + } else { + // on retry, the iter is borked, so we need a new one + iter = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + } + callIter(_.next()) } - iterator.next() - } - // Record last returned response, to know where to restart in case of reattach. - lastReturnedResponseId = Some(ret.getResponseId) - if (ret.hasResultComplete) { - resultComplete = true - releaseExecute(None) // release all - } else { - releaseExecute(lastReturnedResponseId) // release until this response + // Record last returned response, to know where to restart in case of reattach. + lastReturnedResponseId = Some(ret.getResponseId) + if (ret.hasResultComplete) { + releaseAll() + } else { + releaseUntil(lastReturnedResponseId.get) + } + ret + } catch { + case NonFatal(ex) => + releaseAll() // ReleaseExecute on server after error. + throw ex } - ret } override def hasNext(): Boolean = synchronized { @@ -132,47 +135,95 @@ class ExecutePlanResponseReattachableIterator( return false } var firstTry = true - retry { - if (firstTry) { - // on first try, we use the existing iterator. - firstTry = false - } else { - // on retry, the iterator is borked, so we need a new one - iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) - } - var hasNext = iterator.hasNext() - // Graceful reattach: - // If iterator ended, but there was no ResultComplete, it means that there is more, - // and we need to reattach. 
- if (!hasNext && !resultComplete) { - do { - iterator = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) - assert(!resultComplete) // shouldn't change... - hasNext = iterator.hasNext() - // It's possible that the new iterator will be empty, so we need to loop to get another. - // Eventually, there will be a non empty iterator, because there's always a ResultComplete - // at the end of the stream. - } while (!hasNext) + try { + retry { + if (firstTry) { + // on first try, we use the existing iter. + firstTry = false + } else { + // on retry, the iter is borked, so we need a new one + iter = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + } + var hasNext = callIter(_.hasNext()) + // Graceful reattach: + // If iter ended, but there was no ResultComplete, it means that there is more, + // and we need to reattach. + if (!hasNext && !resultComplete) { + do { + iter = rawBlockingStub.reattachExecute(createReattachExecuteRequest()) + assert(!resultComplete) // shouldn't change... + hasNext = callIter(_.hasNext()) + // It's possible that the new iter will be empty, so we need to loop to get another. + // Eventually, there will be a non empty iter, because there is always a + // ResultComplete inserted by the server at the end of the stream. + } while (!hasNext) + } + hasNext } - hasNext + } catch { + case NonFatal(ex) => + releaseAll() // ReleaseExecute on server after error. + throw ex } } + override def close(): Unit = { + releaseAll() + } + /** - * Inform the server to release the execution. + * Inform the server to release the buffered execution results until and including given result. * * This will send an asynchronous RPC which will not block this iterator, the iterator can * continue to be consumed. + */ + private def releaseUntil(untilResponseId: String): Unit = { + if (!resultComplete) { + val request = createReleaseExecuteRequest(Some(untilResponseId)) + rawAsyncStub.releaseExecute(request, createRetryingReleaseExecuteResponseObserer(request)) + } + } + + /** + * Inform the server to release the execution, either because all results were consumed, or the + * execution finished with error and the error was received. * - * Release with untilResponseId informs the server that the iterator has been consumed until and - * including response with that responseId, and these responses can be freed. + * This will send an asynchronous RPC which will not block this. The client continues executing, + * and if the release fails, server is equipped to deal with abandoned executions. + */ + private def releaseAll(): Unit = { + if (!resultComplete) { + val request = createReleaseExecuteRequest(None) + rawAsyncStub.releaseExecute(request, createRetryingReleaseExecuteResponseObserer(request)) + resultComplete = true + } + } + + /** + * Call next() or hasNext() on the iterator. If this fails with this operationId not existing on + * the server, this means that the initial ExecutePlan request didn't even reach the server. In + * that case, attempt to start again with ExecutePlan. * - * Release with None means that the responses have been completely consumed and informs the - * server that the completed execution can be completely freed. + * Called inside retry block, so retryable failure will get handled upstream. 
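+   * If responses were already received from this execution, OPERATION_NOT_FOUND is not
+   * recoverable and an IllegalStateException is thrown instead.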
*/ - private def releaseExecute(untilResponseId: Option[String]): Unit = { - val request = createReleaseExecuteRequest(untilResponseId) - rawAsyncStub.releaseExecute(request, createRetryingReleaseExecuteResponseObserer(request)) + private def callIter[V](iterFun: java.util.Iterator[proto.ExecutePlanResponse] => V) = { + try { + iterFun(iter) + } catch { + case ex: StatusRuntimeException + if StatusProto + .fromThrowable(ex) + .getMessage + .contains("INVALID_HANDLE.OPERATION_NOT_FOUND") => + if (lastReturnedResponseId.isDefined) { + throw new IllegalStateException( + "OPERATION_NOT_FOUND on the server but responses were already received from it.", + ex) + } + // Try a new ExecutePlan, and throw upstream for retry. + iter = rawBlockingStub.executePlan(initialRequest) + throw new GrpcRetryHandler.RetryException + } } /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 1a42ec821d84f..64d1e5c488ab4 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -16,23 +16,31 @@ */ package org.apache.spark.sql.connect.client +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +import com.google.rpc.ErrorInfo import io.grpc.StatusRuntimeException import io.grpc.protobuf.StatusProto -import org.apache.spark.{SparkException, SparkThrowable} +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.util.JsonUtils -private[client] object GrpcExceptionConverter { +private[client] object GrpcExceptionConverter extends JsonUtils { def convert[T](f: => T): T = { try { f } catch { case e: StatusRuntimeException => - throw toSparkThrowable(e) + throw toThrowable(e) } } - def convertIterator[T](iter: java.util.Iterator[T]): java.util.Iterator[T] = { - new java.util.Iterator[T] { + def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = { + new CloseableIterator[T] { override def hasNext: Boolean = { convert { iter.hasNext @@ -44,14 +52,50 @@ private[client] object GrpcExceptionConverter { iter.next() } } + + override def close(): Unit = { + convert { + iter.close() + } + } } } - private def toSparkThrowable(ex: StatusRuntimeException): SparkThrowable with Throwable = { - val status = StatusProto.fromThrowable(ex) - // TODO: Add finer grained error conversion - new SparkException(status.getMessage, ex.getCause) + private def errorConstructor[T <: Throwable: ClassTag]( + throwableCtr: (String, Throwable) => T): (String, (String, Throwable) => Throwable) = { + val className = implicitly[reflect.ClassTag[T]].runtimeClass.getName + (className, throwableCtr) } -} + private val errorFactory = Map( + errorConstructor((message, _) => new ParseException(None, message, Origin(), Origin())), + errorConstructor((message, cause) => new AnalysisException(message, cause = Option(cause)))) + + private def errorInfoToThrowable(info: ErrorInfo, message: String): Option[Throwable] = { + val classes = + mapper.readValue(info.getMetadataOrDefault("classes", "[]"), classOf[Array[String]]) + classes + .find(errorFactory.contains) + .map { cls => + val constructor = 
errorFactory.get(cls).get + constructor(message, null) + } + } + + private def toThrowable(ex: StatusRuntimeException): Throwable = { + val status = StatusProto.fromThrowable(ex) + + val fallbackEx = new SparkException(status.getMessage, ex.getCause) + + val errorInfoOpt = status.getDetailsList.asScala + .find(_.is(classOf[ErrorInfo])) + + if (errorInfoOpt.isEmpty) { + return fallbackEx + } + + errorInfoToThrowable(errorInfoOpt.get.unpack(classOf[ErrorInfo]), status.getMessage) + .getOrElse(fallbackEx) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala index ef446399f1674..6dad5b4b3a9b4 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala @@ -45,13 +45,13 @@ private[client] class GrpcRetryHandler(private val retryPolicy: GrpcRetryHandler * @tparam U * The type of the response. */ - class RetryIterator[T, U](request: T, call: T => java.util.Iterator[U]) - extends java.util.Iterator[U] { + class RetryIterator[T, U](request: T, call: T => CloseableIterator[U]) + extends CloseableIterator[U] { private var opened = false // we only retry if it fails on first call when using the iterator - private var iterator = call(request) + private var iter = call(request) - private def retryIter[V](f: java.util.Iterator[U] => V) = { + private def retryIter[V](f: Iterator[U] => V) = { if (!opened) { opened = true var firstTry = true @@ -61,26 +61,30 @@ private[client] class GrpcRetryHandler(private val retryPolicy: GrpcRetryHandler firstTry = false } else { // on retry, we need to call the RPC again. - iterator = call(request) + iter = call(request) } - f(iterator) + f(iter) } } else { - f(iterator) + f(iter) } } override def next: U = { - retryIter(_.next()) + retryIter(_.next) } override def hasNext: Boolean = { - retryIter(_.hasNext()) + retryIter(_.hasNext) + } + + override def close(): Unit = { + iter.close() } } object RetryIterator { - def apply[T, U](request: T, call: T => java.util.Iterator[U]): RetryIterator[T, U] = + def apply[T, U](request: T, call: T => CloseableIterator[U]): RetryIterator[T, U] = new RetryIterator(request, call) } @@ -164,7 +168,9 @@ private[client] object GrpcRetryHandler extends Logging { try { return fn } catch { - case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < retryPolicy.maxRetries => + case NonFatal(e) + if (retryPolicy.canRetry(e) || e.isInstanceOf[RetryException]) + && currentRetryNum < retryPolicy.maxRetries => logWarning( s"Non fatal error during RPC execution: $e, " + s"retrying (currentRetryNum=$currentRetryNum)") @@ -209,4 +215,10 @@ private[client] object GrpcRetryHandler extends Logging { maxBackoff: FiniteDuration = FiniteDuration(1, "min"), backoffMultiplier: Double = 4.0, canRetry: Throwable => Boolean = retryException) {} + + /** + * An exception that can be thrown upstream when inside retry and which will be retryable + * regardless of policy. 
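+   * Such exceptions are retried even if the policy's canRetry returns false for them.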
+ */ + class RetryException extends Throwable } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 3d20be88888c3..a028df536cf88 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -75,10 +75,11 @@ private[sql] class SparkConnectClient( /** * Execute the plan and return response iterator. * - * It returns an open iterator. The caller needs to ensure that this iterator is fully consumed, - * otherwise resources held by a re-attachable query may be left dangling until server timeout. + * It returns CloseableIterator. For resource management it is better to close it once you are + * done. If you don't close it, it and the underlying data will be cleaned up once the iterator + * is garbage collected. */ - def execute(plan: proto.Plan): java.util.Iterator[proto.ExecutePlanResponse] = { + def execute(plan: proto.Plan): CloseableIterator[proto.ExecutePlanResponse] = { artifactManager.uploadAllClassFileArtifacts() val request = proto.ExecutePlanRequest .newBuilder() diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 93c32aa2954a3..609e84779fbfc 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -27,14 +27,14 @@ import org.apache.arrow.vector.types.pojo import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder} -import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator} +import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator} import org.apache.spark.sql.connect.client.util.Cleanable import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ArrowUtils private[sql] class SparkResult[T]( - responses: java.util.Iterator[proto.ExecutePlanResponse], + responses: CloseableIterator[proto.ExecutePlanResponse], allocator: BufferAllocator, encoder: AgnosticEncoder[T], timeZoneId: String) @@ -198,22 +198,22 @@ private[sql] class SparkResult[T]( /** * Returns an iterator over the contents of the result. */ - def iterator: java.util.Iterator[T] with AutoCloseable = + def iterator: CloseableIterator[T] = buildIterator(destructive = false) /** * Returns an destructive iterator over the contents of the result. 
*/ - def destructiveIterator: java.util.Iterator[T] with AutoCloseable = + def destructiveIterator: CloseableIterator[T] = buildIterator(destructive = true) - private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = { - new java.util.Iterator[T] with AutoCloseable { - private[this] var iterator: CloseableIterator[T] = _ + private def buildIterator(destructive: Boolean): CloseableIterator[T] = { + new CloseableIterator[T] { + private[this] var iter: CloseableIterator[T] = _ private def initialize(): Unit = { - if (iterator == null) { - iterator = new ArrowDeserializingIterator( + if (iter == null) { + iter = new ArrowDeserializingIterator( createEncoder(encoder, schema), new ConcatenatingArrowStreamReader( allocator, @@ -225,17 +225,17 @@ private[sql] class SparkResult[T]( override def hasNext: Boolean = { initialize() - iterator.hasNext + iter.hasNext } override def next(): T = { initialize() - iterator.next() + iter.next() } override def close(): Unit = { - if (iterator != null) { - iterator.close() + if (iter != null) { + iter.close() } } } @@ -246,7 +246,7 @@ private[sql] class SparkResult[T]( */ override def close(): Unit = cleaner.close() - override val cleaner: AutoCloseable = new SparkResultCloseable(resultMap) + override val cleaner: AutoCloseable = new SparkResultCloseable(resultMap, responses) private class ResultMessageIterator(destructive: Boolean) extends AbstractMessageIterator { private[this] var totalBytesRead = 0L @@ -296,7 +296,12 @@ private[sql] class SparkResult[T]( } } -private[client] class SparkResultCloseable(resultMap: mutable.Map[Int, (Long, Seq[ArrowMessage])]) +private[client] class SparkResultCloseable( + resultMap: mutable.Map[Int, (Long, Seq[ArrowMessage])], + responses: CloseableIterator[proto.ExecutePlanResponse]) extends AutoCloseable { - override def close(): Unit = resultMap.values.foreach(_._2.foreach(_.close())) + override def close(): Unit = { + resultMap.values.foreach(_._2.foreach(_.close())) + responses.close() + } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 509ceffc55282..55dd640f1b6b1 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors} import org.apache.spark.sql.types.Decimal diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala index ed27336985416..b9badc5c936fa 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala @@ -40,8 +40,6 @@ private[arrow] object ArrowEncoderUtils { } } 
-trait CloseableIterator[E] extends Iterator[E] with AutoCloseable - private[arrow] object StructVectors { def unapply(v: AnyRef): Option[(StructVector, Seq[FieldVector])] = v match { case root: VectorSchemaRoot => Option((null, root.getFieldVectors.asScala.toSeq)) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index c4a2cfa8a850f..9e67522711c6e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} +import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.util.ArrowUtils diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index e2f3be02ad3ae..404bd1b078ba4 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -75,34 +75,6 @@ abstract class StreamingQueryListener extends Serializable { def onQueryTerminated(event: QueryTerminatedEvent): Unit } -/** - * Py4J allows a pure interface so this proxy is required. - */ -private[spark] trait PythonStreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit - - def onQueryProgress(event: QueryProgressEvent): Unit - - def onQueryIdle(event: QueryIdleEvent): Unit - - def onQueryTerminated(event: QueryTerminatedEvent): Unit -} - -private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener) - extends StreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event) - - def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event) - - override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event) - - def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event) -} - /** * Companion object of [[StreamingQueryListener]] that defines the listener events. 
* @since 3.5.0 diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java index c8210a7a485b1..6e5fb72d4964b 100644 --- a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java +++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java @@ -16,21 +16,26 @@ */ package org.apache.spark.sql; +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; + import org.junit.*; import static org.junit.Assert.*; import static org.apache.spark.sql.Encoders.*; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.RowFactory.create; import org.apache.spark.sql.connect.client.SparkConnectClient; import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils; - -import java.math.BigDecimal; -import java.util.Arrays; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.types.StructType; /** * Tests for the encoders class. */ -public class JavaEncoderSuite { +public class JavaEncoderSuite implements Serializable { private static SparkSession spark; @BeforeClass @@ -91,4 +96,22 @@ public void testSimpleEncoders() { dataset(DECIMAL(), bigDec(1000, 2), bigDec(2, 2)) .select(sum(v)).as(DECIMAL()).head().setScale(2)); } + + @Test + public void testRowEncoder() { + final StructType schema = new StructType() + .add("a", "int") + .add("b", "string"); + final Dataset df = spark.range(3) + .map(new MapFunction() { + @Override + public Row call(Long i) { + return create(i.intValue(), "s" + i); + } + }, + Encoders.row(schema)) + .filter(col("a").geq(1)); + final List expected = Arrays.asList(create(1, "s1"), create(2, "s2")); + Assert.assertEquals(expected, df.collectAsList()); + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index 00a6bcc9b5c45..fa97498f7e77a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -46,7 +46,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(databasesWithPattern.length == 0) val database = spark.catalog.getDatabase(db) assert(database.name == db) - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getDatabase("notExists") }.getMessage assert(message.contains("SCHEMA_NOT_FOUND")) @@ -141,7 +141,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(spark.catalog.listTables().collect().map(_.name).toSet == Set(parquetTableName)) } } - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getTable(parquetTableName) }.getMessage assert(message.contains("TABLE_OR_VIEW_NOT_FOUND")) @@ -207,7 +207,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(spark.catalog.getFunction(absFunctionName).name == absFunctionName) val notExistsFunction = "notExists" assert(!spark.catalog.functionExists(notExistsFunction)) - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getFunction(notExistsFunction) }.getMessage assert(message.contains("UNRESOLVED_ROUTINE")) diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 1403d460b516f..ebd3d037bba5c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -43,6 +43,12 @@ import org.apache.spark.sql.types._ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { + test("throw ParseException") { + intercept[ParseException] { + spark.sql("selet 1").collect() + } + } + test("spark deep recursion") { var df = spark.range(1) for (a <- 1 to 500) { @@ -88,7 +94,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assume(IntegrationTestUtils.isSparkHiveJarAvailable) withTable("test_martin") { // Fails, because table does not exist. - assertThrows[SparkException] { + assertThrows[AnalysisException] { spark.sql("select * from test_martin").collect() } // Execute eager, DML @@ -177,7 +183,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM StructField("job", StringType) :: Nil)) .csv(testDataPath.toString) // Failed because the path cannot be provided both via option and load method (csv). - assertThrows[SparkException] { + assertThrows[AnalysisException] { df.collect() } } @@ -381,7 +387,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM val df = spark.range(10) val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath // Failed because the path cannot be provided both via option and save method. - assertThrows[SparkException] { + assertThrows[AnalysisException] { df.write.option("path", outputFolderPath.toString).save(outputFolderPath.toString) } } @@ -755,7 +761,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM private def checkSameResult[E](expected: scala.collection.Seq[E], dataset: Dataset[E]): Unit = { dataset.withResult { result => - assert(expected === result.iterator.asScala.toBuffer) + assert(expected === result.iterator.toBuffer) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala index 525a5902525ad..ac64d4411a866 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.SparkException import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{StringType, StructType} @@ -279,7 +278,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { test("drop with col(*)") { val df = createDF() - val ex = intercept[SparkException] { + val ex = intercept[AnalysisException] { df.na.drop("any", Seq("*")).collect() } assert(ex.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index ad75887a7e2db..380ca2fb72b31 
100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.Timestamp import java.util.Arrays -import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions._ @@ -179,7 +178,7 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { assert(values == Arrays.asList[String]("0", "8,6,4,2,0", "1", "9,7,5,3,1")) // Star is not allowed as group sort column - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { grouped .flatMapSortedGroups(col("*")) { (g, iter) => Iterator(String.valueOf(g), iter.mkString(",")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 97fb46bf48af4..f06744399f833 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -16,6 +16,10 @@ */ package org.apache.spark.sql +import java.util.concurrent.{Executors, Phaser} + +import scala.util.control.NonFatal + import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} import org.apache.spark.sql.connect.client.util.ConnectFunSuite @@ -24,6 +28,10 @@ import org.apache.spark.sql.connect.client.util.ConnectFunSuite * Tests for non-dataframe related SparkSession operations. 
*/ class SparkSessionSuite extends ConnectFunSuite { + private val connectionString1: String = "sc://test.it:17845" + private val connectionString2: String = "sc://test.me:14099" + private val connectionString3: String = "sc://doit:16845" + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") @@ -32,16 +40,15 @@ class SparkSessionSuite extends ConnectFunSuite { } test("remote") { - val session = SparkSession.builder().remote("sc://test.me:14099").getOrCreate() + val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) session.close() } test("getOrCreate") { - val connectionString = "sc://test.it:17865" - val session1 = SparkSession.builder().remote(connectionString).getOrCreate() - val session2 = SparkSession.builder().remote(connectionString).getOrCreate() + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + val session2 = SparkSession.builder().remote(connectionString1).getOrCreate() try { assert(session1 eq session2) } finally { @@ -51,9 +58,8 @@ class SparkSessionSuite extends ConnectFunSuite { } test("create") { - val connectionString = "sc://test.it:17845" - val session1 = SparkSession.builder().remote(connectionString).create() - val session2 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() try { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) @@ -64,8 +70,7 @@ class SparkSessionSuite extends ConnectFunSuite { } test("newSession") { - val connectionString = "sc://doit:16845" - val session1 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString3).create() val session2 = session1.newSession() try { assert(session1 ne session2) @@ -92,5 +97,126 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } + session.close() + } + + test("Default/Active session") { + // Make sure we start with a clean slate. + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + intercept[IllegalStateException](SparkSession.active) + + // Create a session + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + assert(SparkSession.active == session1) + + // Create another session... 
+ val session2 = SparkSession.builder().remote(connectionString2).create() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Clear sessions + SparkSession.clearDefaultSession() + assert(SparkSession.getDefaultSession.isEmpty) + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + + // Flip sessions + SparkSession.setActiveSession(session1) + SparkSession.setDefaultSession(session2) + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.contains(session1)) + + // Close session1 + session1.close() + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.isEmpty) + + // Close session2 + session2.close() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + + test("active session in multiple threads") { + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + val phaser = new Phaser(2) + val executor = Executors.newFixedThreadPool(2) + def execute(block: Phaser => Unit): java.util.concurrent.Future[Boolean] = { + executor.submit[Boolean] { () => + try { + block(phaser) + true + } catch { + case NonFatal(e) => + phaser.forceTermination() + throw e + } + } + } + + try { + val script1 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + session1.close() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + val script2 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + val internalSession = SparkSession.builder().remote(connectionString3).getOrCreate() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + internalSession.close() + assert(SparkSession.getActiveSession.isEmpty) + } + assert(script1.get()) + assert(script2.get()) + assert(SparkSession.getActiveSession.contains(session2)) + session2.close() + assert(SparkSession.getActiveSession.isEmpty) + } finally { + executor.shutdown() + 
} } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 6e577e0f21257..2bf9c41fb2cbd 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), // SparkSession - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.setDefaultSession"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"), diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index dd0e9347ac88b..7a8e8465a70cc 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._ +import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType} diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 151e828b3e903..79dbadba5bb07 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -750,9 +750,7 @@ message ReleaseExecuteRequest { optional string client_type = 4; // Release and close operation completely. - // Note: This should be called when the server side operation is finished, and ExecutePlan or - // ReattachExecute are finished processing the result stream, or inside onComplete / onError. - // This will not interrupt a running execution, but block until it's finished. + // This will also interrupt the query if it is still running, and wait for it to be torn down.
message ReleaseAll {} // Release all responses from the operation response stream up to and including diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 662288177dc69..930ccae5d4c76 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -46,6 +46,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends private var completed: Boolean = false + private val lock = new Object + /** Launches the execution in a background thread, returns immediately. */ def start(): Unit = { executionThread.start() @@ -62,7 +64,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends * true if it was not interrupted before, false if it was already interrupted or completed. */ def interrupt(): Boolean = { - synchronized { + lock.synchronized { if (!interrupted && !completed) { // checking completed prevents sending interrupt onError after onCompleted interrupted = true @@ -119,7 +121,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends // Inner executeInternal is wrapped by execute() for error handling. private def executeInternal() = { // synchronized - check if already got interrupted while starting. - synchronized { + lock.synchronized { if (interrupted) { throw new InterruptedException() } @@ -160,14 +162,23 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends s"${executeHolder.request.getPlan.getOpTypeCase} not supported.") } - if (executeHolder.reattachable) { - // Reattachable execution sends a ResultComplete at the end of the stream - // to signal that there isn't more coming. - executeHolder.responseObserver.onNext(createResultComplete()) - } - synchronized { - // Prevent interrupt after onCompleted, and throwing error to an already closed stream. - completed = true + lock.synchronized { + // Synchronized before sending ResultComplete, and up until completing the result stream + // to prevent a situation in which a client of reattachable execution receives + // ResultComplete, and proceeds to send ReleaseExecute, and that triggers an interrupt + // before it finishes. + + if (interrupted) { + // check if it got interrupted at the very last moment + throw new InterruptedException() + } + completed = true // no longer interruptible + + if (executeHolder.reattachable) { + // Reattachable execution sends a ResultComplete at the end of the stream + // to signal that there isn't more coming. 
+ executeHolder.responseObserver.onNext(createResultComplete()) + } executeHolder.responseObserver.onCompleted() } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index f4b33ae961a2f..f70a17e580a3e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.connect.planner -import java.io.IOException - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Try +import com.google.common.base.Throwables import com.google.common.collect.{Lists, Maps} import com.google.protobuf.{Any => ProtoAny, ByteString} import io.grpc.{Context, Status, StatusRuntimeException} @@ -1518,11 +1517,15 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { logDebug(s"Unpack using class loader: ${Utils.getContextOrSparkClassLoader}") Utils.deserialize[T](fun.getPayload.toByteArray, Utils.getContextOrSparkClassLoader) } catch { - case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => - throw new ClassNotFoundException( - s"Failed to load class correctly due to ${e.getCause}. " + - "Make sure the artifact where the class is defined is installed by calling" + - " session.addArtifact.") + case t: Throwable => + Throwables.getRootCause(t) match { + case nsm: NoSuchMethodException => + throw new ClassNotFoundException( + s"Failed to load class correctly due to $nsm. " + + "Make sure the artifact where the class is defined is installed by calling" + + " session.addArtifact.") + case _ => throw t + } } } @@ -3097,10 +3100,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER => val listenerId = command.getRemoveListener.getId - val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId) - session.streams.removeListener(listener) - sessionHolder.removeCachedListener(listenerId) - respBuilder.setRemoveListener(true) + sessionHolder.getListener(listenerId) match { + case Some(listener) => + session.streams.removeListener(listener) + sessionHolder.removeCachedListener(listenerId) + respBuilder.setRemoveListener(true) + case None => + respBuilder.setRemoveListener(false) + } case StreamingQueryManagerCommand.CommandCase.LIST_LISTENERS => respBuilder.getListListenersBuilder diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 998faf327d03a..4f1037b86c9f2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -87,11 +87,13 @@ object StreamingForeachBatchHelper extends Logging { val port = SparkConnectService.localPort val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" - val runner = StreamingPythonRunner(pythonFn, connectUrl) + val runner = StreamingPythonRunner( + pythonFn, + connectUrl, + sessionHolder.sessionId, + 
"pyspark.sql.connect.streaming.worker.foreachBatch_worker") val (dataOut, dataIn) = - runner.init( - sessionHolder.sessionId, - "pyspark.sql.connect.streaming.worker.foreachBatch_worker") + runner.init() val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala index d915bc9349609..9b2a931ec4acb 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.streaming.StreamingQueryListener /** * A helper class for handling StreamingQueryListener related functionality in Spark Connect. Each * instance of this class starts a python process, inside which has the python handling logic. - * When new a event is received, it is serialized to json, and passed to the python process. + * When a new event is received, it is serialized to json, and passed to the python process. */ class PythonStreamingQueryListener( listener: SimplePythonFunction, @@ -32,12 +32,15 @@ class PythonStreamingQueryListener( pythonExec: String) extends StreamingQueryListener { - val port = SparkConnectService.localPort - val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" - val runner = StreamingPythonRunner(listener, connectUrl) + private val port = SparkConnectService.localPort + private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" + private val runner = StreamingPythonRunner( + listener, + connectUrl, + sessionHolder.sessionId, + "pyspark.sql.connect.streaming.worker.listener_worker") - val (dataOut, _) = - runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker") + val (dataOut, _) = runner.init() override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { PythonRDD.writeUTF(event.json, dataOut) @@ -63,7 +66,7 @@ class PythonStreamingQueryListener( dataOut.flush() } - // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes. - // Similar to foreachBatch when we need to exit the process when the query ends. - // In listener semantics, we need to exit the process when removeListener is called. + private[spark] def stopListenerProcess(): Unit = { + runner.stop() + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index a49c0a8bacf98..4eb90f9f1639a 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -156,11 +156,11 @@ private[connect] class ExecuteHolder( } /** - * Close the execution and remove it from the session. Note: It blocks joining the - * ExecuteThreadRunner thread, so it assumes that it's called when the execution is ending or - * ended. If it is desired to kill the execution, interrupt() should be called first. + * Close the execution and remove it from the session. Note: it first interrupts the runner if + * it's still running, and it waits for it to finish. 
*/ def close(): Unit = { + runner.interrupt() runner.join() eventsManager.postClosed() sessionHolder.removeExecuteHolder(operationId) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 310bb9208c21d..29134f0dc0ded 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.InvalidPlanInput +import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.{SystemClock} import org.apache.spark.util.Utils @@ -220,20 +221,22 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } /** - * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, throw - * [[InvalidPlanInput]]. + * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, return + * None. */ - private[connect] def getListenerOrThrow(id: String): StreamingQueryListener = { + private[connect] def getListener(id: String): Option[StreamingQueryListener] = { Option(listenerCache.get(id)) - .getOrElse { - throw InvalidPlanInput(s"No listener with id $id is found in the session $sessionId") - } } /** - * Removes corresponding StreamingQueryListener by ID. + * Removes corresponding StreamingQueryListener by ID. Terminates the python process if it's a + * Spark Connect PythonStreamingQueryListener. */ - private[connect] def removeCachedListener(id: String): StreamingQueryListener = { + private[connect] def removeCachedListener(id: String): Unit = { + listenerCache.get(id) match { + case pyListener: PythonStreamingQueryListener => pyListener.stopListenerProcess() + case _ => // do nothing + } listenerCache.remove(id) } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index c29a9b9b62958..e833d12c4f595 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -493,24 +493,37 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .setSessionId(sessionId) .build() - // The observer is executed inside this thread. So - // we can perform the checks inside the observer. + // Even though the observer is executed inside this thread, this thread is also executing + // the SparkConnectService. If we throw an exception inside it, it will be caught by + // the ErrorUtils.handleError wrapping instance.executePlan and turned into an onError + // call with StatusRuntimeException, which will be eaten here. 
+ var failures: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer[String]() instance.executePlan( request, new StreamObserver[proto.ExecutePlanResponse] { override def onNext(v: proto.ExecutePlanResponse): Unit = { - fail("this should not receive responses") + // The query receives some pre-execution responses such as schema, but should + // never proceed to execution and get query results. + if (v.hasArrowBatch) { + failures += s"this should not receive query results but got $v" + } } override def onError(throwable: Throwable): Unit = { - assert(throwable.isInstanceOf[StatusRuntimeException]) - verifyEvents.onError(throwable) + try { + assert(throwable.isInstanceOf[StatusRuntimeException]) + verifyEvents.onError(throwable) + } catch { + case t: Throwable => + failures += s"assertion $t validating processing onError($throwable)." + } } override def onCompleted(): Unit = { - fail("this should not complete") + failures += "this should not complete" } }) + assert(failures.isEmpty, s"this should have no failures but got $failures") verifyEvents.onCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 47ac3df4cc62c..3495536a3508f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -171,7 +171,10 @@ private class ShuffleStatus( * Get the map output that corresponding to a given mapId. */ def getMapStatus(mapId: Long): Option[MapStatus] = withReadLock { - mapIdToMapIndex.get(mapId).map(mapStatuses(_)) + mapIdToMapIndex.get(mapId).map(mapStatuses(_)) match { + case Some(null) => None + case m => m + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 813a14acd19e4..8c054d24b10d7 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -503,8 +503,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria logWarning(msg) } - val executorOptsKey = EXECUTOR_JAVA_OPTIONS.key - // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => val warning = @@ -518,16 +516,19 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } // Validate spark.executor.extraJavaOptions - getOption(executorOptsKey).foreach { javaOpts => - if (javaOpts.contains("-Dspark")) { - val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + - "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." - throw new Exception(msg) - } - if (javaOpts.contains("-Xmx")) { - val msg = s"$executorOptsKey is not allowed to specify max heap memory settings " + - s"(was '$javaOpts'). Use spark.executor.memory instead." - throw new Exception(msg) + Seq(EXECUTOR_JAVA_OPTIONS.key, "spark.executor.defaultJavaOptions").foreach { executorOptsKey => + getOption(executorOptsKey).foreach { javaOpts => + if (javaOpts.contains("-Dspark")) { + val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + + "Set them directly on a SparkConf or in a properties file " + + "when using ./bin/spark-submit." + throw new Exception(msg) + } + if (javaOpts.contains("-Xmx")) { + val msg = s"$executorOptsKey is not allowed to specify max heap memory settings " + + s"(was '$javaOpts'). Use spark.executor.memory instead." 
+ throw new Exception(msg) + } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index d4fd9485675fa..a079743c847ae 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -29,27 +29,36 @@ import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTH private[spark] object StreamingPythonRunner { - def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = { - new StreamingPythonRunner(func, connectUrl) + def apply( + func: PythonFunction, + connectUrl: String, + sessionId: String, + workerModule: String + ): StreamingPythonRunner = { + new StreamingPythonRunner(func, connectUrl, sessionId, workerModule) } } -private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String) - extends Logging { +private[spark] class StreamingPythonRunner( + func: PythonFunction, + connectUrl: String, + sessionId: String, + workerModule: String) extends Logging { private val conf = SparkEnv.get.conf protected val bufferSize: Int = conf.get(BUFFER_SIZE) protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val envVars: java.util.Map[String, String] = func.envVars private val pythonExec: String = func.pythonExec + private var pythonWorker: Option[Socket] = None protected val pythonVer: String = func.pythonVer /** * Initializes the Python worker for streaming functions. Sets up Spark Connect session * to be used with the functions. */ - def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = { - logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec") + def init(): (DataOutputStream, DataInputStream) = { + logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec") val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") @@ -57,14 +66,19 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) - conf.set(PYTHON_USE_DAEMON, false) envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl) - val pythonWorkerFactory = - new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap) - val (worker: Socket, _) = pythonWorkerFactory.createSimpleWorker() - - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val prevConf = conf.get(PYTHON_USE_DAEMON) + conf.set(PYTHON_USE_DAEMON, false) + try { + val (worker, _) = env.createPythonWorker( + pythonExec, workerModule, envVars.asScala.toMap) + pythonWorker = Some(worker) + } finally { + conf.set(PYTHON_USE_DAEMON, prevConf) + } + + val stream = new BufferedOutputStream(pythonWorker.get.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // TODO(SPARK-44461): verify python version @@ -78,11 +92,21 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str dataOut.write(command.toArray) dataOut.flush() - val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + val dataIn = new DataInputStream( + new BufferedInputStream(pythonWorker.get.getInputStream, bufferSize)) val resFromPython = dataIn.readInt() logInfo(s"Runner initialization returned $resFromPython") 
(dataOut, dataIn) } + + /** + * Stops the Python worker. + */ + def stop(): Unit = { + pythonWorker.foreach { worker => + SparkEnv.get.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala index 8558f765175fc..27040e83533ff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala @@ -149,6 +149,7 @@ class EventLogFileCompactor( val logWriter = new CompactedEventLogFileWriter(lastIndexEventLogPath, "dummy", None, lastIndexEventLogPath.getParent.toUri, sparkConf, hadoopConf) + val startTime = System.currentTimeMillis() logWriter.start() eventLogFiles.foreach { file => EventFilter.applyFilterToFile(fs, filters, file.getPath, @@ -158,6 +159,8 @@ class EventLogFileCompactor( ) } logWriter.stop() + val duration = System.currentTimeMillis() - startTime + logInfo(s"Finished rewriting eventLog files to ${logWriter.logPath} in $duration ms.") logWriter.logPath } diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 7d2923fdf3752..d09abff2773b8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -140,9 +140,9 @@ private[serializer] class GenericAvroSerializer[D <: GenericContainer] case Some(s) => new Schema.Parser().setValidateDefaults(false).parse(s) case None => throw new SparkException( - "Error reading attempting to read avro data -- encountered an unknown " + - s"fingerprint: $fingerprint, not sure what schema to use.
This could happen " + - "if you registered additional schemas after starting your spark context.") + errorClass = "ERROR_READING_AVRO_UNKNOWN_FINGERPRINT", + messageParameters = Map("fingerprint" -> fingerprint.toString), + cause = null) } }) } else { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 826d6789f88ee..f75942cbb879f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -66,15 +66,21 @@ class KryoSerializer(conf: SparkConf) private val bufferSizeKb = conf.get(KRYO_SERIALIZER_BUFFER_SIZE) if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { - throw new IllegalArgumentException(s"${KRYO_SERIALIZER_BUFFER_SIZE.key} must be less than " + - s"2048 MiB, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} MiB.") + throw new SparkIllegalArgumentException( + errorClass = "INVALID_KRYO_SERIALIZER_BUFFER_SIZE", + messageParameters = Map( + "bufferSizeConfKey" -> KRYO_SERIALIZER_BUFFER_SIZE.key, + "bufferSizeConfValue" -> ByteUnit.KiB.toMiB(bufferSizeKb).toString)) } private val bufferSize = ByteUnit.KiB.toBytes(bufferSizeKb).toInt val maxBufferSizeMb = conf.get(KRYO_SERIALIZER_MAX_BUFFER_SIZE).toInt if (maxBufferSizeMb >= ByteUnit.GiB.toMiB(2)) { - throw new IllegalArgumentException(s"${KRYO_SERIALIZER_MAX_BUFFER_SIZE.key} must be less " + - s"than 2048 MiB, got: $maxBufferSizeMb MiB.") + throw new SparkIllegalArgumentException( + errorClass = "INVALID_KRYO_SERIALIZER_BUFFER_SIZE", + messageParameters = Map( + "bufferSizeConfKey" -> KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, + "bufferSizeConfValue" -> maxBufferSizeMb.toString)) } private val maxBufferSize = ByteUnit.MiB.toBytes(maxBufferSizeMb).toInt @@ -183,7 +189,10 @@ class KryoSerializer(conf: SparkConf) .foreach { reg => reg.registerClasses(kryo) } } catch { case e: Exception => - throw new SparkException(s"Failed to register classes with Kryo", e) + throw new SparkException( + errorClass = "FAILED_REGISTER_CLASS_WITH_KRYO", + messageParameters = Map.empty, + cause = e) } } @@ -442,8 +451,12 @@ private[spark] class KryoSerializerInstance( kryo.writeClassAndObject(output, t) } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => - throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. 
To avoid this, " + - s"increase ${KRYO_SERIALIZER_MAX_BUFFER_SIZE.key} value.", e) + throw new SparkException( + errorClass = "KRYO_BUFFER_OVERFLOW", + messageParameters = Map( + "exceptionMsg" -> e.getMessage, + "bufferSizeConfKey" -> KRYO_SERIALIZER_MAX_BUFFER_SIZE.key), + cause = e) } finally { releaseKryo(kryo) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index c1708c320c5d4..726622673650d 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -98,7 +98,7 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends {rdd.storageLevel} {rdd.numCachedPartitions.toString} - {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} + {"%.2f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memoryUsed)} {Utils.bytesToString(rdd.diskUsed)} diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index ef7c4cbbb799c..4bfb4a2c68c40 100644 --- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.io; import org.apache.commons.io.FileUtils; -import org.apache.commons.lang3.RandomUtils; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -25,6 +24,7 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.util.concurrent.ThreadLocalRandom; import static org.junit.Assert.assertEquals; @@ -33,7 +33,8 @@ */ public abstract class GenericFileInputStreamSuite { - private byte[] randomBytes; + // Create a byte array of size 2 MB with random bytes + private byte[] randomBytes = new byte[2 * 1024 * 1024]; protected File inputFile; @@ -41,8 +42,7 @@ public abstract class GenericFileInputStreamSuite { @Before public void setUp() throws IOException { - // Create a byte array of size 2 MB with random bytes - randomBytes = RandomUtils.nextBytes(2 * 1024 * 1024); + ThreadLocalRandom.current().nextBytes(randomBytes); inputFile = File.createTempFile("temp-file", ".tmp"); FileUtils.writeByteArrayToFile(inputFile, randomBytes); } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7ac3d0092c8ce..450ff01921a83 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -1083,4 +1083,30 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { rpcEnv.shutdown() } } + + test("SPARK-44658: ShuffleStatus.getMapStatus should return None") { + val bmID = BlockManagerId("a", "hostA", 1000) + val mapStatus = MapStatus(bmID, Array(1000L, 10000L), mapTaskId = 0) + val shuffleStatus = new ShuffleStatus(1000) + shuffleStatus.addMapOutput(mapIndex = 1, mapStatus) + shuffleStatus.removeMapOutput(mapIndex = 1, bmID) + assert(shuffleStatus.getMapStatus(0).isEmpty) + } + + test("SPARK-44661: getMapOutputLocation should not throw NPE") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + try { + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + 
tracker.registerShuffle(0, 1, 1) + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000), + Array(2L), 0)) + tracker.removeOutputsOnHost("hostA") + assert(tracker.getMapOutputLocation(0, 0) == None) + } finally { + tracker.stop() + rpcEnv.shutdown() + } + } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 74fd78162218b..75e22e1418b4a 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -498,6 +498,20 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } } + + test("SPARK-44650: spark.executor.defaultJavaOptions Check illegal java options") { + val conf = new SparkConf() + conf.validateSettings() + conf.set(EXECUTOR_JAVA_OPTIONS.key, "-Dspark.foo=bar") + intercept[Exception] { + conf.validateSettings() + } + conf.remove(EXECUTOR_JAVA_OPTIONS.key) + conf.set("spark.executor.defaultJavaOptions", "-Dspark.foo=bar") + intercept[Exception] { + conf.validateSettings() + } + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 39607621b4c45..998ad21a50d25 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.master import java.net.ServerSocket +import java.util.concurrent.ThreadLocalRandom -import org.apache.commons.lang3.RandomUtils import org.apache.curator.test.TestingServer import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} @@ -117,7 +117,7 @@ class PersistenceEngineSuite extends SparkFunSuite { } private def findFreePort(conf: SparkConf): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) + val candidatePort = ThreadLocalRandom.current().nextInt(1024, 65536) Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 905bb8110736d..3e69f01c09c46 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.metrics import java.io.{File, PrintWriter} +import java.util.concurrent.ThreadLocalRandom import scala.collection.mutable.ArrayBuffer -import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} @@ -54,7 +54,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext Utils.tryWithResource(new PrintWriter(tmpFile)) { pw => for (x <- 1 to numRecords) { // scalastyle:off println - pw.println(RandomUtils.nextInt(0, numBuckets)) + pw.println(ThreadLocalRandom.current().nextInt(0, numBuckets)) // scalastyle:on println } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd66dc2c5fb0..dcb69f812a7db 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} import java.nio.ByteBuffer import java.nio.file.Files +import java.util.concurrent.ThreadLocalRandom import scala.collection.JavaConverters._ import scala.collection.mutable @@ -31,7 +32,6 @@ import scala.reflect.ClassTag import scala.reflect.classTag import com.esotericsoftware.kryo.KryoException -import org.apache.commons.lang3.RandomUtils import org.mockito.{ArgumentCaptor, ArgumentMatchers => mc} import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify, when} import org.scalatest.PrivateMethodTester @@ -1887,7 +1887,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } - val candidatePort = RandomUtils.nextInt(1024, 65536) + val candidatePort = ThreadLocalRandom.current().nextInt(1024, 65536) val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") @@ -2274,7 +2274,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } - val candidatePort = RandomUtils.nextInt(1024, 65536) + val candidatePort = ThreadLocalRandom.current().nextInt(1024, 65536) val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 718c6856cb31f..d1e25bf8a2346 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -48,8 +48,8 @@ class StoragePageSuite extends SparkFunSuite { val rdd2 = new RDDStorageInfo(2, "rdd2", - 10, - 5, + 1000, + 56, StorageLevel.DISK_ONLY.description, 0L, 200L, @@ -58,8 +58,8 @@ class StoragePageSuite extends SparkFunSuite { val rdd3 = new RDDStorageInfo(3, "rdd3", - 10, - 10, + 1000, + 103, StorageLevel.MEMORY_AND_DISK_SER.description, 400L, 500L, @@ -94,19 +94,20 @@ class StoragePageSuite extends SparkFunSuite { assert((xmlNodes \\ "tr").size === 3) assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === - Seq("1", "rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) + Seq("1", "rdd1", "Memory Deserialized 1x Replicated", "10", "100.00%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd/?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === - Seq("2", "rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) + Seq("2", "rdd2", "Disk Serialized 1x Replicated", "56", "5.60%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd/?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === - Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) + Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "103", "10.30%", "400.0 B", + "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd/?id=3")) diff --git 
a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 44876fe69120d..8ba1ff1b3b1eb 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -57,9 +57,7 @@ - + files="sql/api/src/main/java/org/apache/spark/sql/streaming/Trigger.java"/> diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index e0588ae934cd2..59e3b69b349b8 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -220,7 +220,7 @@ git clean -d -f -x rm -f .gitignore cd .. -export MAVEN_OPTS="-Xss128m -Xmx12g" +export MAVEN_OPTS="-Xss128m -Xmx12g -XX:ReservedCodeCacheSize=1g" if [[ "$1" == "package" ]]; then # Source and binary tarballs diff --git a/dev/free_disk_space b/dev/free_disk_space index 87a09a524f4fd..2b2b20f814e02 100755 --- a/dev/free_disk_space +++ b/dev/free_disk_space @@ -34,7 +34,11 @@ sudo rm -rf /usr/local/share/powershell sudo rm -rf /usr/local/share/chromium sudo rm -rf /usr/local/lib/android sudo rm -rf /usr/local/lib/node_modules + sudo rm -rf /opt/az +sudo rm -rf /opt/hostedtoolcache/CodeQL +sudo rm -rf /opt/hostedtoolcache/go +sudo rm -rf /opt/hostedtoolcache/node sudo apt-get remove --purge -y '^aspnet.*' sudo apt-get remove --purge -y '^dotnet-.*' diff --git a/dev/free_disk_space_container b/dev/free_disk_space_container new file mode 100755 index 0000000000000..cc3b74643e4fa --- /dev/null +++ b/dev/free_disk_space_container @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +echo "==================================" +echo "Free up disk space on CI system" +echo "==================================" + +echo "Listing 100 largest packages" +dpkg-query -Wf '${Installed-Size}\t${Package}\n' | sort -n | tail -n 100 +df -h + +echo "Removing large packages" +rm -rf /__t/CodeQL +rm -rf /__t/go +rm -rf /__t/node + +df -h diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 9d7b29e25b49b..b69e682f239c8 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -73,3 +73,5 @@ RUN python3.9 -m pip install grpcio protobuf googleapis-common-protos grpcio-sta # Add torch as a testing dependency for TorchDistributor RUN python3.9 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu RUN python3.9 -m pip install torcheval +# Add Deepspeed as a testing dependency for DeepspeedTorchDistributor +RUN python3.9 -m pip install deepspeed diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 9e45e0facefc1..c5be1957a7dcb 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -668,7 +668,6 @@ def __hash__(self): "pyspark.pandas.indexes.datetimes", "pyspark.pandas.indexes.timedelta", "pyspark.pandas.indexes.multi", - "pyspark.pandas.indexes.numeric", "pyspark.pandas.spark.accessors", "pyspark.pandas.spark.utils", "pyspark.pandas.typedef.typehints", @@ -817,10 +816,9 @@ def __hash__(self): pyspark_connect = Module( name="pyspark-connect", - dependencies=[pyspark_sql, pyspark_ml, connect], + dependencies=[pyspark_sql, connect], source_file_regexes=[ "python/pyspark/sql/connect", - "python/pyspark/ml/connect", ], python_test_goals=[ # sql doctests @@ -871,6 +869,21 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_pandas_udf_scalar", "pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg", "pyspark.sql.tests.connect.test_parity_pandas_udf_window", + ], + excluded_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and + # they aren't available there + ], +) + + +pyspark_ml_connect = Module( + name="pyspark-ml-connect", + dependencies=[pyspark_connect, pyspark_ml], + source_file_regexes=[ + "python/pyspark/ml/connect", + ], + python_test_goals=[ # ml doctests "pyspark.ml.connect.functions", # ml unittests diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py index 816c982bd60e9..e79d864c32095 100755 --- a/dev/sparktestsupport/utils.py +++ b/dev/sparktestsupport/utils.py @@ -112,25 +112,28 @@ def determine_modules_to_test(changed_modules, deduplicated=True): >>> sorted([x.name for x in determine_modules_to_test([modules.sql])]) ... # doctest: +NORMALIZE_WHITESPACE ['avro', 'connect', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver', - 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', - 'pyspark-pandas-connect', 'pyspark-pandas-slow', 'pyspark-pandas-slow-connect', 'pyspark-sql', - 'pyspark-testing', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] + 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-ml-connect', 'pyspark-mllib', + 'pyspark-pandas', 'pyspark-pandas-connect', 'pyspark-pandas-slow', + 'pyspark-pandas-slow-connect', 'pyspark-sql', 'pyspark-testing', 'repl', 'sparkr', 'sql', + 'sql-kafka-0-10'] >>> sorted([x.name for x in determine_modules_to_test( ... [modules.sparkr, modules.sql], deduplicated=False)]) ... 
# doctest: +NORMALIZE_WHITESPACE ['avro', 'connect', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver', - 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', - 'pyspark-pandas-connect', 'pyspark-pandas-slow', 'pyspark-pandas-slow-connect', 'pyspark-sql', - 'pyspark-testing', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] + 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-ml-connect', 'pyspark-mllib', + 'pyspark-pandas', 'pyspark-pandas-connect', 'pyspark-pandas-slow', + 'pyspark-pandas-slow-connect', 'pyspark-sql', 'pyspark-testing', 'repl', 'sparkr', 'sql', + 'sql-kafka-0-10'] >>> sorted([x.name for x in determine_modules_to_test( ... [modules.sql, modules.core], deduplicated=False)]) ... # doctest: +NORMALIZE_WHITESPACE ['avro', 'catalyst', 'connect', 'core', 'docker-integration-tests', 'examples', 'graphx', 'hive', 'hive-thriftserver', 'mllib', 'mllib-local', 'protobuf', 'pyspark-connect', - 'pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-connect', - 'pyspark-pandas-slow', 'pyspark-pandas-slow-connect', 'pyspark-resource', 'pyspark-sql', - 'pyspark-streaming', 'pyspark-testing', 'repl', 'root', 'sparkr', 'sql', 'sql-kafka-0-10', - 'streaming', 'streaming-kafka-0-10', 'streaming-kinesis-asl'] + 'pyspark-core', 'pyspark-ml', 'pyspark-ml-connect', 'pyspark-mllib', 'pyspark-pandas', + 'pyspark-pandas-connect', 'pyspark-pandas-slow', 'pyspark-pandas-slow-connect', + 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming', 'pyspark-testing', 'repl', + 'root', 'sparkr', 'sql', 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10', + 'streaming-kinesis-asl'] """ modules_to_test = set() for module in changed_modules: diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index d47ff3987f95b..3e87edad0aadd 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -104,10 +104,12 @@ Once you've set up this file, you can launch or stop your cluster with the follo - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-workers.sh` - Starts a worker instance on each machine specified in the `conf/workers` file. - `sbin/start-worker.sh` - Starts a worker instance on the machine the script is executed on. +- `sbin/start-connect-server.sh` - Starts a Spark Connect server on the machine the script is executed on. - `sbin/start-all.sh` - Starts both a master and a number of workers as described above. - `sbin/stop-master.sh` - Stops the master that was started via the `sbin/start-master.sh` script. - `sbin/stop-worker.sh` - Stops all worker instances on the machine the script is executed on. - `sbin/stop-workers.sh` - Stops all worker instances on the machines specified in the `conf/workers` file. +- `sbin/stop-connect-server.sh` - Stops all Spark Connect server instances on the machine the script is executed on. - `sbin/stop-all.sh` - Stops both the master and the workers as described above. Note that these scripts must be executed on the machine you want to run the Spark master on, not your local machine. 
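For the `sbin/start-connect-server.sh` entry added above, the following is a minimal, hypothetical sketch of how a client would reach such a server with the Spark Connect Scala client API exercised elsewhere in this change (builder().remote(...), getOrCreate(), range(), count(), close()); the host name, the port (15002 is assumed to be the server's default), and the object name are illustrative assumptions, not part of the patch.

```scala
// Hypothetical usage sketch, not part of the patch: connect to a standalone
// cluster whose Spark Connect server was started with sbin/start-connect-server.sh.
import org.apache.spark.sql.SparkSession

object ConnectServerSmokeTest {
  def main(args: Array[String]): Unit = {
    // "sc://" is the Spark Connect scheme; localhost:15002 is an assumed default endpoint.
    val spark = SparkSession.builder()
      .remote("sc://localhost:15002")
      .getOrCreate()
    try {
      // The plan is executed remotely by the cluster behind the Connect server.
      println(spark.range(10).count())
    } finally {
      // close() releases the client connection, mirroring session.close() in the test suites above.
      spark.close()
    }
  }
}
```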
diff --git a/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md b/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md index d9f14b5a55ef8..eb5ca2a0169d1 100644 --- a/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md +++ b/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md @@ -27,10 +27,10 @@ This error class has the following derived error classes: ## BOTH_POSITIONAL_AND_NAMED -A positional argument and named argument both referred to the same parameter. +A positional argument and named argument both referred to the same parameter. Please remove the named argument referring to this parameter. ## DOUBLE_NAMED_ARGUMENT_REFERENCE -More than one named argument referred to the same parameter. +More than one named argument referred to the same parameter. Please assign a value only once. diff --git a/docs/sql-error-conditions-malformed-record-in-parsing-error-class.md b/docs/sql-error-conditions-malformed-record-in-parsing-error-class.md index ab9582dffcd31..1cc0327af67ba 100644 --- a/docs/sql-error-conditions-malformed-record-in-parsing-error-class.md +++ b/docs/sql-error-conditions-malformed-record-in-parsing-error-class.md @@ -30,6 +30,10 @@ This error class has the following derived error classes: Parsing JSON arrays as structs is forbidden. +## CANNOT_PARSE_STRING_AS_DATATYPE + +Cannot parse the value `` of the field `` as target spark data type `` from the input type ``. + ## WITHOUT_SUGGESTION diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 6ea16d7ef31b3..b59bb1789488e 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -484,6 +484,12 @@ SQLSTATE: none assigned Not found an encoder of the type `` to Spark SQL internal representation. Consider to change the input type to one of supported at '``/sql-ref-datatypes.html'. +### ERROR_READING_AVRO_UNKNOWN_FINGERPRINT + +SQLSTATE: none assigned + +Error reading avro data -- encountered an unknown fingerprint: ``, not sure what schema to use. This could happen if you registered additional schemas after starting your spark context. + ### EVENT_TIME_IS_NOT_ON_TIMESTAMP_TYPE SQLSTATE: none assigned @@ -520,6 +526,12 @@ Failed preparing of the function `` for call. Please, double check fun Failed parsing struct: ``. +### FAILED_REGISTER_CLASS_WITH_KRYO + +SQLSTATE: none assigned + +Failed to register classes with Kryo. + ### FAILED_RENAME_PATH [SQLSTATE: 42K04](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -972,6 +984,12 @@ Cannot convert JSON root field to target Spark type. Input schema `` can only contain STRING as a key type for a MAP. +### INVALID_KRYO_SERIALIZER_BUFFER_SIZE + +SQLSTATE: F0000 + +The value of the config "``" must be less than 2048 MiB, but got `` MiB. + ### [INVALID_LAMBDA_FUNCTION_CALL](sql-error-conditions-invalid-lambda-function-call-error-class.html) SQLSTATE: none assigned @@ -1163,6 +1181,12 @@ SQLSTATE: none assigned The join condition `` has the invalid type ``, expected "BOOLEAN". +### KRYO_BUFFER_OVERFLOW + +SQLSTATE: none assigned + +Kryo serialization failed: ``. To avoid this, increase "``" value. + ### LOAD_DATA_PATH_NOT_EXISTS SQLSTATE: none assigned @@ -1563,7 +1587,7 @@ The `` clause may be used at most once per `` operation. 
[SQLSTATE: 4274K](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) -Cannot invoke function `` because the parameter named `` is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally or by name) and retry the query again. +Cannot invoke function `` because the parameter named `` is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally at index `` or by name) and retry the query again. ### REQUIRES_SINGLE_PART_NAMESPACE @@ -1586,6 +1610,12 @@ The function `` cannot be found. Verify the spelling and correctnes If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog. To tolerate the error on drop use DROP FUNCTION IF EXISTS. +### RULE_ID_NOT_FOUND + +[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception) + +Not found an id for the rule name "``". Please modify RuleIdCollection.scala if you are adding a new rule. + ### SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION SQLSTATE: none assigned @@ -1778,7 +1808,7 @@ Parameter `` of function `` requires the `` because it contains positional argument(s) following named argument(s); please rearrange them so the positional arguments come first and then retry the query again. +Cannot invoke function `` because it contains positional argument(s) following the named argument assigned to ``; please rearrange them so the positional arguments come first and then retry the query again. ### UNKNOWN_PROTOBUF_MESSAGE_TYPE diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala deleted file mode 100644 index 3b08b9d62cfce..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.util - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import scala.collection.mutable.HashSet - -import org.apache.xbean.asm9.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm9.Opcodes._ - -import org.apache.spark.util.Utils - -/** - * Includes an utility function to test whether a function accesses a specific attribute - * of an object. - */ -private[graphx] object BytecodeUtils { - - /** - * Test whether the given closure invokes the specified method in the specified class. 
- */ - def invokedMethod(closure: AnyRef, targetClass: Class[_], targetMethod: String): Boolean = { - if (_invokedMethod(closure.getClass, "apply", targetClass, targetMethod)) { - true - } else { - // look at closures enclosed in this closure - for (f <- closure.getClass.getDeclaredFields - if f.getType.getName.startsWith("scala.Function")) { - f.setAccessible(true) - if (invokedMethod(f.get(closure), targetClass, targetMethod)) { - return true - } - } - false - } - } - - private def _invokedMethod(cls: Class[_], method: String, - targetClass: Class[_], targetMethod: String): Boolean = { - - val seen = new HashSet[(Class[_], String)] - var stack = List[(Class[_], String)]((cls, method)) - - while (stack.nonEmpty) { - val c = stack.head._1 - val m = stack.head._2 - stack = stack.tail - seen.add((c, m)) - val finder = new MethodInvocationFinder(c.getName, m) - getClassReader(c).accept(finder, 0) - for (classMethod <- finder.methodsInvoked) { - if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { - return true - } else if (!seen.contains(classMethod)) { - stack = classMethod :: stack - } - } - } - false - } - - /** - * Get an ASM class reader for a given class from the JAR that loaded it. - */ - private def getClassReader(cls: Class[_]): ClassReader = { - // Copy data over, before delegating to ClassReader - else we can run out of open file handles. - val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" - val resourceStream = cls.getResourceAsStream(className) - // todo: Fixme - continuing with earlier behavior ... - if (resourceStream == null) return new ClassReader(resourceStream) - - val baos = new ByteArrayOutputStream(128) - Utils.copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) - } - - /** - * Given the class name, return whether we should look into the class or not. This is used to - * skip examining a large quantity of Java or Scala classes that we know for sure wouldn't access - * the closures. Note that the class name is expected in ASM style (i.e. use "/" instead of "."). - */ - private def skipClass(className: String): Boolean = { - val c = className - c.startsWith("java/") || c.startsWith("scala/") || c.startsWith("javax/") - } - - /** - * Find the set of methods invoked by the specified method in the specified class. - * For example, after running the visitor, - * MethodInvocationFinder("spark/graph/Foo", "test") - * its methodsInvoked variable will contain the set of methods invoked directly by - * Foo.test(). Interface invocations are not returned as part of the result set because we cannot - * determine the actual method invoked by inspecting the bytecode. 
- */ - private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM9) { - - val methodsInvoked = new HashSet[(Class[_], String)] - - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - if (name == methodName) { - new MethodVisitor(ASM9) { - override def visitMethodInsn( - op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = { - if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { - if (!skipClass(owner)) { - methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) - } - } - } - } - } else { - null - } - } - } -} diff --git a/pom.xml b/pom.xml index b0d97c2aa0501..76e3596edd430 100644 --- a/pom.xml +++ b/pom.xml @@ -3600,7 +3600,7 @@ scala-2.13 - 2.13.11 + 2.13.8 2.13 @@ -3659,10 +3659,6 @@ --> -Wconf:cat=unused-imports&src=org\/apache\/spark\/graphx\/impl\/VertexPartitionBase.scala:s -Wconf:cat=unused-imports&src=org\/apache\/spark\/graphx\/impl\/VertexPartitionBaseOps.scala:s - - -Wconf:msg=Implicit definition should have explicit type:s diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 14fa43b56725e..8da132f5de3c5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -71,9 +71,15 @@ object MimaExcludes { // [SPARK-44507][SQL][CONNECT] Move AnalysisException to sql/api. ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException$"), + // [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RowFactory"), // [SPARK-44535][CONNECT][SQL] Move required Streaming API to sql/api ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupStateTimeout"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.OutputMode") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.OutputMode"), + // [SPARK-44198][CORE] Support propagation of the log level to the executors + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$"), + // [SPARK-44692][CONNECT][SQL] Move Trigger(s) to sql/api + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.Trigger") ) // Default exclude rules diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e585d5dd2b25c..bd65d3c4bd4aa 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -286,9 +286,7 @@ object SparkBuild extends PomBuild { // TODO(SPARK-43850): Remove the following suppression rules and remove `import scala.language.higherKinds` // from the corresponding files when Scala 2.12 is no longer supported. 
"-Wconf:cat=unused-imports&src=org\\/apache\\/spark\\/graphx\\/impl\\/VertexPartitionBase.scala:s", - "-Wconf:cat=unused-imports&src=org\\/apache\\/spark\\/graphx\\/impl\\/VertexPartitionBaseOps.scala:s", - // SPARK-40497 Upgrade Scala to 2.13.11 and suppress `Implicit definition should have explicit type` - "-Wconf:msg=Implicit definition should have explicit type:s" + "-Wconf:cat=unused-imports&src=org\\/apache\\/spark\\/graphx\\/impl\\/VertexPartitionBaseOps.scala:s" ) } } @@ -451,7 +449,7 @@ object SparkBuild extends PomBuild { enable(Unidoc.settings)(spark) /* Sql-api ANTLR generation settings */ - enable(Catalyst.settings)(sqlApi) + enable(SqlApi.settings)(sqlApi) /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -1171,7 +1169,7 @@ object OldDeps { ) } -object Catalyst { +object SqlApi { import com.simplytyped.Antlr4Plugin import com.simplytyped.Antlr4Plugin.autoImport._ diff --git a/python/docs/source/getting_started/index.rst b/python/docs/source/getting_started/index.rst index 3c1c7d80863ce..5f6d306651b92 100644 --- a/python/docs/source/getting_started/index.rst +++ b/python/docs/source/getting_started/index.rst @@ -40,3 +40,4 @@ The list below is the contents of this quickstart page: quickstart_df quickstart_connect quickstart_ps + testing_pyspark diff --git a/python/docs/source/getting_started/testing_pyspark.ipynb b/python/docs/source/getting_started/testing_pyspark.ipynb new file mode 100644 index 0000000000000..268ace04376ba --- /dev/null +++ b/python/docs/source/getting_started/testing_pyspark.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4ee2125b-f889-47e6-9c3d-8bd63a253683", + "metadata": {}, + "source": [ + "# Testing PySpark\n", + "\n", + "This guide is a reference for writing robust tests for PySpark code.\n", + "\n", + "To view the docs for PySpark test utils, see here. To see the code for PySpark built-in test utils, check out the Spark repository here. To see the JIRA board tickets for the PySpark test framework, see here." + ] + }, + { + "cell_type": "markdown", + "id": "0e8ee4b6-9544-45e1-8a91-e71ed8ef8b9d", + "metadata": {}, + "source": [ + "## Build a PySpark Application\n", + "Here is an example for how to start a PySpark application. Feel free to skip to the next section, “Testing your PySpark Application,” if you already have an application you’re ready to test.\n", + "\n", + "First, start your Spark Session." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9af4a35b-17e8-4e45-816b-34c14c5902f7", + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.sql import SparkSession \n", + "from pyspark.sql.functions import col \n", + "\n", + "# Create a SparkSession \n", + "spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate() " + ] + }, + { + "cell_type": "markdown", + "id": "4a4c6efe-91f5-4e18-b4b2-b0401c2368e4", + "metadata": {}, + "source": [ + "Next, create a DataFrame." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3b483dd8-3a76-41c6-9206-301d7ef314d6", + "metadata": {}, + "outputs": [], + "source": [ + "sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + "\n", + "df = spark.createDataFrame(sample_data)" + ] + }, + { + "cell_type": "markdown", + "id": "e0f44333-0e08-470b-9fa2-38f59e3dbd63", + "metadata": {}, + "source": [ + "Now, let’s define and apply a transformation function to our DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a6c0b766-af5f-4e1d-acf8-887d7cf0b0b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---+--------+\n", + "|age| name|\n", + "+---+--------+\n", + "| 30| John D.|\n", + "| 25|Alice G.|\n", + "| 35| Bob T.|\n", + "| 28| Eve A.|\n", + "+---+--------+\n", + "\n" + ] + } + ], + "source": [ + "from pyspark.sql.functions import col, regexp_replace\n", + "\n", + "# Remove additional spaces in name\n", + "def remove_extra_spaces(df, column_name):\n", + " # Remove extra spaces from the specified column\n", + " df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), \"\\\\s+\", \" \"))\n", + " \n", + " return df_transformed\n", + "\n", + "transformed_df = remove_extra_spaces(df, \"name\")\n", + "\n", + "transformed_df.show()" + ] + }, + { + "cell_type": "markdown", + "id": "530beaa6-aabf-43a1-ad2b-361f267e9608", + "metadata": {}, + "source": [ + "## Testing your PySpark Application\n", + "Now let’s test our PySpark transformation function. \n", + "\n", + "One option is to simply eyeball the resulting DataFrame. However, this can be impractical for large DataFrame or input sizes.\n", + "\n", + "A better way is to write tests. Here are some examples of how we can test our code. The examples below apply for Spark 3.5 and above versions.\n", + "\n", + "Note that these examples are not exhaustive, as there are many other test framework alternatives which you can use instead of `unittest` or `pytest`. The built-in PySpark testing util functions are standalone, meaning they can be compatible with any test framework or CI test pipeline.\n" + ] + }, + { + "cell_type": "markdown", + "id": "d84a9fc1-9768-4af4-bfbf-e832f23334dc", + "metadata": {}, + "source": [ + "### Option 1: Using Only PySpark Built-in Test Utility Functions\n", + "\n", + "For simple ad-hoc validation cases, PySpark testing utils like `assertDataFrameEqual` and `assertSchemaEqual` can be used in a standalone context.\n", + "You could easily test PySpark code in a notebook session. 
For example, say you want to assert equality between two DataFrames:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8e533732-ee40-4cd0-9669-8eb92973908a", + "metadata": {}, + "outputs": [], + "source": [ + "import pyspark.testing\n", + "from pyspark.testing.utils import assertDataFrameEqual\n", + "\n", + "# Example 1\n", + "df1 = spark.createDataFrame(data=[(\"1\", 1000), (\"2\", 3000)], schema=[\"id\", \"amount\"])\n", + "df2 = spark.createDataFrame(data=[(\"1\", 1000), (\"2\", 3000)], schema=[\"id\", \"amount\"])\n", + "assertDataFrameEqual(df1, df2) # pass, DataFrames are identical" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2d77a6be-1e50-4c1a-8a44-85cf7dcec3f3", + "metadata": {}, + "outputs": [], + "source": [ + "# Example 2\n", + "df1 = spark.createDataFrame(data=[(\"1\", 0.1), (\"2\", 3.23)], schema=[\"id\", \"amount\"])\n", + "df2 = spark.createDataFrame(data=[(\"1\", 0.109), (\"2\", 3.23)], schema=[\"id\", \"amount\"])\n", + "assertDataFrameEqual(df1, df2, rtol=1e-1) # pass, DataFrames are approx equal by rtol" + ] + }, + { + "cell_type": "markdown", + "id": "76ade5f2-4a1f-4601-9d2a-80da9da950ff", + "metadata": {}, + "source": [ + "You can also simply compare two DataFrame schemas:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "74393af5-40fb-4d04-87cb-265971ffe6d0", + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.testing.utils import assertSchemaEqual\n", + "from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType\n", + "\n", + "s1 = StructType([StructField(\"names\", ArrayType(DoubleType(), True), True)])\n", + "s2 = StructType([StructField(\"names\", ArrayType(DoubleType(), True), True)])\n", + "\n", + "assertSchemaEqual(s1, s2) # pass, schemas are identical" + ] + }, + { + "cell_type": "markdown", + "id": "c67be105-f6b1-4083-ad11-9e819331eae8", + "metadata": {}, + "source": [ + "### Option 2: Using [Unit Test](https://docs.python.org/3/library/unittest.html)\n", + "For more complex testing scenarios, you may want to use a testing framework.\n", + "\n", + "One of the most popular testing framework options is unit tests. Let’s walk through how you can use the built-in Python `unittest` library to write PySpark tests. For more information about the `unittest` library, see here: https://docs.python.org/3/library/unittest.html. \n", + "\n", + "First, you will need a Spark session. You can use the `@classmethod` decorator from the `unittest` package to take care of setting up and tearing down a Spark session." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "54093761-0b49-4aee-baec-2d29bcf13f9f", + "metadata": {}, + "outputs": [], + "source": [ + "import unittest\n", + "\n", + "class PySparkTestCase(unittest.TestCase):\n", + " @classmethod\n", + " def setUpClass(cls):\n", + " cls.spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate() \n", + "\n", + " \n", + " @classmethod\n", + " def tearDownClass(cls):\n", + " cls.spark.stop()" + ] + }, + { + "cell_type": "markdown", + "id": "3de27500-8526-412e-bf09-6927a760c5d7", + "metadata": {}, + "source": [ + "Now let’s write a `unittest` class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "34feb5e1-944f-4f6b-9c5f-3b0bf68c7d05", + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.testing.utils import assertDataFrameEqual\n", + "\n", + "class TestTranformation(PySparkTestCase):\n", + " def test_single_space(self):\n", + " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + " \n", + " # Create a Spark DataFrame\n", + " original_df = spark.createDataFrame(sample_data)\n", + " \n", + " # Apply the transformation function from before\n", + " transformed_df = remove_extra_spaces(original_df, \"name\")\n", + " \n", + " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}]\n", + " \n", + " expected_df = spark.createDataFrame(expected_data)\n", + " \n", + " assertDataFrameEqual(transformed_df, expected_df)\n" + ] + }, + { + "cell_type": "markdown", + "id": "319a690f-71bd-4886-bd3a-424e866525c2", + "metadata": {}, + "source": [ + "When run, `unittest` will pick up all functions with a name beginning with “test.”" + ] + }, + { + "cell_type": "markdown", + "id": "7d79e53d-cc1e-4fdf-a069-478337bed83d", + "metadata": {}, + "source": [ + "### Option 3: Using [Pytest](https://docs.pytest.org/en/7.1.x/contents.html)\n", + "\n", + "We can also write our tests with `pytest`, which is one of the most popular Python testing frameworks. For more information about `pytest`, see the docs here: https://docs.pytest.org/en/7.1.x/contents.html.\n", + "\n", + "Using a `pytest` fixture allows us to share a spark session across tests, tearing it down when the tests are complete." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "60a4f304-1911-4b4d-8ed9-00ecc8b0890b", + "metadata": {}, + "outputs": [], + "source": [ + "import pytest\n", + "\n", + "@pytest.fixture\n", + "def spark_fixture():\n", + " spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate()\n", + " yield spark" + ] + }, + { + "cell_type": "markdown", + "id": "fcb4e26a-9bfc-48a5-8aca-538697d66642", + "metadata": {}, + "source": [ + "We can then define our tests like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "fa5db3a1-7305-44b7-ab84-f5ed55fd2ba9", + "metadata": {}, + "outputs": [], + "source": [ + "import pytest\n", + "from pyspark.testing.utils import assertDataFrameEqual\n", + "\n", + "def test_single_space(spark_fixture):\n", + " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + " \n", + " # Create a Spark DataFrame\n", + " original_df = spark.createDataFrame(sample_data)\n", + " \n", + " # Apply the transformation function from before\n", + " transformed_df = remove_extra_spaces(original_df, \"name\")\n", + " \n", + " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}]\n", + " \n", + " expected_df = spark.createDataFrame(expected_data)\n", + "\n", + " assertDataFrameEqual(transformed_df, expected_df)" + ] + }, + { + "cell_type": "markdown", + "id": "0fc3f394-3260-4e42-82cf-1a7edc859151", + "metadata": {}, + "source": [ + "When you run your test file with the `pytest` command, it will pick up all functions that have their name beginning with “test.”" + ] + }, + { + "cell_type": "markdown", + "id": "d8f50eee-5d0b-4719-b505-1b3ff05c16e8", + "metadata": {}, + "source": [ + "## Putting It All Together!\n", + "\n", + "Let’s see all the steps together, in a Unit Test example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "a2ea9dec-0ac0-4c23-8770-d6cc226d2e97", + "metadata": {}, + "outputs": [], + "source": [ + "# pkg/etl.py\n", + "import unittest\n", + "\n", + "from pyspark.sql import SparkSession \n", + "from pyspark.sql.functions import col\n", + "from pyspark.sql.functions import regexp_replace\n", + "from pyspark.testing.utils import assertDataFrameEqual\n", + "\n", + "# Create a SparkSession \n", + "spark = SparkSession.builder.appName(\"Sample PySpark ETL\").getOrCreate() \n", + "\n", + "sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + "\n", + "df = spark.createDataFrame(sample_data)\n", + "\n", + "# Define DataFrame transformation function\n", + "def remove_extra_spaces(df, column_name):\n", + " # Remove extra spaces from the specified column using regexp_replace\n", + " df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), \"\\\\s+\", \" \"))\n", + "\n", + " return df_transformed" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "248aede2-feb9-4828-bd9c-8e25e6b194ab", + "metadata": {}, + "outputs": [], + "source": [ + "# pkg/test_etl.py\n", + "import unittest\n", + "\n", + "from pyspark.sql import SparkSession \n", + "\n", + "# Define unit test base class\n", + "class PySparkTestCase(unittest.TestCase):\n", + " @classmethod\n", + " def setUpClass(cls):\n", + " cls.spark = SparkSession.builder.appName(\"Sample PySpark ETL\").getOrCreate() \n", + "\n", + " @classmethod\n", + " def tearDownClass(cls):\n", + " cls.spark.stop()\n", + " \n", + "# Define unit test\n", + "class TestTranformation(PySparkTestCase):\n", + " def test_single_space(self):\n", + " sample_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}] \n", + " \n", + " # Create a Spark DataFrame\n", + " original_df = spark.createDataFrame(sample_data)\n", + " \n", + " # Apply the transformation function from before\n", + " transformed_df = remove_extra_spaces(original_df, \"name\")\n", + " \n", + " expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n", + " {\"name\": \"Alice G.\", \"age\": 25}, \n", + " {\"name\": \"Bob T.\", \"age\": 35}, \n", + " {\"name\": \"Eve A.\", \"age\": 28}]\n", + " \n", + " expected_df = spark.createDataFrame(expected_data)\n", + " \n", + " assertDataFrameEqual(transformed_df, expected_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "a77df5b2-f32e-4d8c-a64b-0078dfa21217", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ran 1 test in 1.734s\n", + "\n", + "OK\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unittest.main(argv=[''], verbosity=0, exit=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jupyter-oss-env", + "language": "python", + "name": "jupyter-oss-env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index f04da62a7051d..4fe8b49c2a9bf 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -22,6 +22,14 @@ Upgrading PySpark Upgrading from PySpark 3.5 to 4.0 --------------------------------- +* In Spark 4.0, ``Int64Index`` and ``Float64Index`` have been removed from pandas API on Spark, ``Index`` should be used directly. +* In Spark 4.0, ``DataFrame.iteritems`` has been removed from pandas API on Spark, use ``DataFrame.items`` instead. +* In Spark 4.0, ``Series.iteritems`` has been removed from pandas API on Spark, use ``Series.items`` instead. +* In Spark 4.0, ``DataFrame.append`` has been removed from pandas API on Spark, use ``ps.concat`` instead. +* In Spark 4.0, ``Series.append`` has been removed from pandas API on Spark, use ``ps.concat`` instead. +* In Spark 4.0, ``DataFrame.mad`` has been removed from pandas API on Spark. +* In Spark 4.0, ``Series.mad`` has been removed from pandas API on Spark. +* In Spark 4.0, ``na_sentinel`` parameter from ``Index.factorize`` and `Series.factorize`` has been removed from pandas API on Spark, use ``use_na_sentinel`` instead. * In Spark 4.0, the default value of ``regex`` parameter for ``Series.str.replace`` has been changed from ``True`` to ``False`` from pandas API on Spark. Additionally, a single character ``pat`` with ``regex=True`` is now treated as a regular expression instead of a string literal. diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst index a8d114187b94b..5f839a803d78a 100644 --- a/python/docs/source/reference/pyspark.pandas/frame.rst +++ b/python/docs/source/reference/pyspark.pandas/frame.rst @@ -79,7 +79,6 @@ Indexing, iteration DataFrame.iloc DataFrame.insert DataFrame.items - DataFrame.iteritems DataFrame.iterrows DataFrame.itertuples DataFrame.keys @@ -155,7 +154,6 @@ Computations / Descriptive Stats DataFrame.ewm DataFrame.kurt DataFrame.kurtosis - DataFrame.mad DataFrame.max DataFrame.mean DataFrame.min @@ -252,7 +250,6 @@ Combining / joining / merging .. autosummary:: :toctree: api/ - DataFrame.append DataFrame.assign DataFrame.merge DataFrame.join diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst index da1579fd72350..e71e81c56dd3e 100644 --- a/python/docs/source/reference/pyspark.pandas/groupby.rst +++ b/python/docs/source/reference/pyspark.pandas/groupby.rst @@ -68,7 +68,6 @@ Computations / Descriptive Stats GroupBy.filter GroupBy.first GroupBy.last - GroupBy.mad GroupBy.max GroupBy.mean GroupBy.median diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst index 15539fa226633..70d463c052a03 100644 --- a/python/docs/source/reference/pyspark.pandas/indexing.rst +++ b/python/docs/source/reference/pyspark.pandas/indexing.rst @@ -166,16 +166,6 @@ Selecting Index.asof Index.isin -.. _api.numeric: - -Numeric Index -------------- -.. autosummary:: - :toctree: api/ - - Int64Index - Float64Index - .. 
_api.categorical: CategoricalIndex diff --git a/python/docs/source/reference/pyspark.pandas/series.rst b/python/docs/source/reference/pyspark.pandas/series.rst index a0119593f96ae..552acec096f69 100644 --- a/python/docs/source/reference/pyspark.pandas/series.rst +++ b/python/docs/source/reference/pyspark.pandas/series.rst @@ -70,7 +70,6 @@ Indexing, iteration Series.keys Series.pop Series.items - Series.iteritems Series.item Series.xs Series.get @@ -148,7 +147,6 @@ Computations / Descriptive Stats Series.ewm Series.filter Series.kurt - Series.mad Series.max Series.mean Series.min @@ -247,7 +245,6 @@ Combining / joining / merging .. autosummary:: :toctree: api/ - Series.append Series.compare Series.replace Series.update diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst index c16ca4f162f5c..f25dbab5f6b9b 100644 --- a/python/docs/source/reference/pyspark.sql/spark_session.rst +++ b/python/docs/source/reference/pyspark.sql/spark_session.rst @@ -28,6 +28,7 @@ See also :class:`SparkSession`. .. autosummary:: :toctree: api/ + SparkSession.active SparkSession.builder.appName SparkSession.builder.config SparkSession.builder.enableHiveSupport diff --git a/python/pyspark/cloudpickle/cloudpickle_fast.py b/python/pyspark/cloudpickle/cloudpickle_fast.py index 63aaffa096b2c..ee1f4b8ee967e 100644 --- a/python/pyspark/cloudpickle/cloudpickle_fast.py +++ b/python/pyspark/cloudpickle/cloudpickle_fast.py @@ -631,7 +631,7 @@ def dump(self, obj): try: return Pickler.dump(self, obj) except RuntimeError as e: - if "recursion" in e.args[0]: + if len(e.args) > 0 and "recursion" in e.args[0]: msg = ( "Could not pickle object as excessively deep recursion " "required." diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index d6f093246dacd..bc32afeb87a9f 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -617,6 +617,11 @@ "Argument `` should be a WindowSpec, got ." ] }, + "NO_ACTIVE_OR_DEFAULT_SESSION" : { + "message" : [ + "No active or default Spark session found. Please create a new Spark session before running the code." + ] + }, "NO_ACTIVE_SESSION" : { "message" : [ "No active Spark session found. Please create a new Spark session before running the code." @@ -738,6 +743,16 @@ "User defined table function encountered an error in the '' method: " ] }, + "UDTF_INVALID_OUTPUT_ROW_TYPE" : { + "message" : [ + "The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got ''. Please make sure that the output rows are of the correct type." + ] + }, + "UDTF_RETURN_NOT_ITERABLE" : { + "message" : [ + "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got ''. Please make sure that the UDTF returns one of these types." + ] + }, "UDTF_RETURN_SCHEMA_MISMATCH" : { "message" : [ "The number of columns in the result does not match the specified schema. Expected column count: , Actual column count: . Please make sure the values returned by the function have the same number of columns as specified in the output schema." @@ -748,6 +763,11 @@ "Mismatch in return type for the UDTF ''. Expected a 'StructType', but got ''. Please ensure the return type is a correctly formatted StructType." 
] }, + "UDTF_SERIALIZATION_ERROR" : { + "message" : [ + "Cannot serialize the UDTF '': " + ] + }, "UNEXPECTED_RESPONSE_FROM_SERVER" : { "message" : [ "Unexpected response from iterator server." diff --git a/python/pyspark/ml/connect/io_utils.py b/python/pyspark/ml/connect/io_utils.py index 9a963086aaf45..a09a244862c58 100644 --- a/python/pyspark/ml/connect/io_utils.py +++ b/python/pyspark/ml/connect/io_utils.py @@ -23,7 +23,7 @@ from urllib.parse import urlparse from typing import Any, Dict, List from pyspark.ml.base import Params -from pyspark.ml.util import _get_active_session +from pyspark.sql import SparkSession from pyspark.sql.utils import is_remote @@ -34,7 +34,7 @@ def _copy_file_from_local_to_fs(local_path: str, dest_path: str) -> None: - session = _get_active_session(is_remote()) + session = SparkSession.active() if is_remote(): session.copyFromLocalToFs(local_path, dest_path) else: @@ -228,7 +228,7 @@ def save(self, path: str, *, overwrite: bool = False) -> None: .. versionadded:: 3.5.0 """ - session = _get_active_session(is_remote()) + session = SparkSession.active() path_exist = True try: session.read.format("binaryFile").load(path).head() @@ -256,7 +256,7 @@ def load(cls, path: str) -> "Params": .. versionadded:: 3.5.0 """ - session = _get_active_session(is_remote()) + session = SparkSession.active() tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_") try: diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py index 6d539933e1d69..c22c31e84e8de 100644 --- a/python/pyspark/ml/connect/tuning.py +++ b/python/pyspark/ml/connect/tuning.py @@ -178,11 +178,12 @@ def _parallelFitTasks( def get_single_task(index: int, param_map: Any) -> Callable[[], Tuple[int, float]]: def single_task() -> Tuple[int, float]: - # Active session is thread-local variable, in background thread the active session - # is not set, the following line sets it as the main thread active session. - active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr] - active_session._jsparkSession # type: ignore[union-attr] - ) + if not is_remote(): + # Active session is thread-local variable, in background thread the active session + # is not set, the following line sets it as the main thread active session. 
+ active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr] + active_session._jsparkSession # type: ignore[union-attr] + ) model = estimator.fit(train, param_map) metric = evaluator.evaluate( diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py index 6ad47322234c5..f3e621c19f0f0 100644 --- a/python/pyspark/ml/tests/connect/test_connect_classification.py +++ b/python/pyspark/ml/tests/connect/test_connect_classification.py @@ -20,7 +20,14 @@ from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin +have_torch = True +try: + import torch # noqa: F401 +except ImportError: + have_torch = False + +@unittest.skipIf(not have_torch, "torch is required") class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py index 2056803d61cf4..a4e79b1dcc10b 100644 --- a/python/pyspark/ml/torch/distributor.py +++ b/python/pyspark/ml/torch/distributor.py @@ -49,7 +49,6 @@ LogStreamingServer, ) from pyspark.ml.dl_util import FunctionPickler -from pyspark.ml.util import _get_active_session def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]: @@ -165,7 +164,7 @@ def __init__( from pyspark.sql.utils import is_remote self.is_remote = is_remote() - self.spark = _get_active_session(self.is_remote) + self.spark = SparkSession.active() # indicate whether the server side is local mode self.is_spark_local_master = False diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 2c90ff3cb7b69..64676947017d0 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -747,16 +747,3 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return f(*args, **kwargs) return cast(FuncT, wrapped) - - -def _get_active_session(is_remote: bool) -> SparkSession: - if not is_remote: - spark = SparkSession.getActiveSession() - else: - import pyspark.sql.connect.session - - spark = pyspark.sql.connect.session._active_spark_session # type: ignore[assignment] - - if spark is None: - raise RuntimeError("An active SparkSession is required for the distributor.") - return spark diff --git a/python/pyspark/pandas/__init__.py b/python/pyspark/pandas/__init__.py index 980aeab2bee87..d8ce385639cec 100644 --- a/python/pyspark/pandas/__init__.py +++ b/python/pyspark/pandas/__init__.py @@ -61,7 +61,6 @@ from pyspark.pandas.indexes.category import CategoricalIndex from pyspark.pandas.indexes.datetimes import DatetimeIndex from pyspark.pandas.indexes.multi import MultiIndex -from pyspark.pandas.indexes.numeric import Float64Index, Int64Index from pyspark.pandas.indexes.timedelta import TimedeltaIndex from pyspark.pandas.series import Series from pyspark.pandas.groupby import NamedAgg @@ -77,8 +76,6 @@ "Series", "Index", "MultiIndex", - "Int64Index", - "Float64Index", "CategoricalIndex", "DatetimeIndex", "TimedeltaIndex", diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index e005fd19b3009..0685af769872a 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -904,8 +904,8 @@ def astype(self: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: 1 2 dtype: int64 - >>> ser.rename("a").to_frame().set_index("a").index.astype('int64') # doctest: +SKIP - Int64Index([1, 2], dtype='int64', name='a') + >>> 
ser.rename("a").to_frame().set_index("a").index.astype('int64') + Index([1, 2], dtype='int64', name='a') """ return self._dtype_op.astype(self, dtype) @@ -1247,8 +1247,8 @@ def shift( 4 23 Name: Col2, dtype: int64 - >>> df.index.shift(periods=3, fill_value=0) # doctest: +SKIP - Int64Index([0, 0, 0, 0, 1], dtype='int64') + >>> df.index.shift(periods=3, fill_value=0) + Index([0, 0, 0, 0, 1], dtype='int64') """ return self._shift(periods, fill_value).spark.analyzed @@ -1341,8 +1341,8 @@ def value_counts( For Index >>> idx = ps.Index([3, 1, 2, 3, 4, np.nan]) - >>> idx # doctest: +SKIP - Float64Index([3.0, 1.0, 2.0, 3.0, 4.0, nan], dtype='float64') + >>> idx + Index([3.0, 1.0, 2.0, 3.0, 4.0, nan], dtype='float64') >>> idx.value_counts().sort_index() 1.0 1 @@ -1511,8 +1511,8 @@ def nunique(self, dropna: bool = True, approx: bool = False, rsd: float = 0.05) 3 >>> idx = ps.Index([1, 1, 2, None]) - >>> idx # doctest: +SKIP - Float64Index([1.0, 1.0, 2.0, nan], dtype='float64') + >>> idx + Index([1.0, 1.0, 2.0, nan], dtype='float64') >>> idx.nunique() 2 @@ -1586,11 +1586,11 @@ def take(self: IndexOpsLike, indices: Sequence[int]) -> IndexOpsLike: Index >>> psidx = ps.Index([100, 200, 300, 400, 500]) - >>> psidx # doctest: +SKIP - Int64Index([100, 200, 300, 400, 500], dtype='int64') + >>> psidx + Index([100, 200, 300, 400, 500], dtype='int64') - >>> psidx.take([0, 2, 4]).sort_values() # doctest: +SKIP - Int64Index([100, 300, 500], dtype='int64') + >>> psidx.take([0, 2, 4]).sort_values() + Index([100, 300, 500], dtype='int64') MultiIndex @@ -1614,7 +1614,7 @@ def take(self: IndexOpsLike, indices: Sequence[int]) -> IndexOpsLike: return cast(IndexOpsLike, self._psdf.iloc[indices].index) def factorize( - self: IndexOpsLike, sort: bool = True, na_sentinel: Optional[int] = -1 + self: IndexOpsLike, sort: bool = True, use_na_sentinel: bool = True ) -> Tuple[IndexOpsLike, pd.Index]: """ Encode the object as an enumerated type or categorical variable. @@ -1625,11 +1625,11 @@ def factorize( Parameters ---------- sort : bool, default True - na_sentinel : int or None, default -1 - Value to mark "not found". If None, will not drop the NaN - from the uniques of the values. - - .. deprecated:: 3.4.0 + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values, effectively assigning them + a distinct category. If False, NaN values will be encoded as non-negative integers, + treating them as unique categories in the encoding process and retaining them in the + set of unique categories in the data. 
Returns ------- @@ -1658,7 +1658,7 @@ def factorize( >>> uniques Index(['a', 'b', 'c'], dtype='object') - >>> codes, uniques = psser.factorize(na_sentinel=None) + >>> codes, uniques = psser.factorize(use_na_sentinel=False) >>> codes 0 1 1 3 @@ -1669,30 +1669,19 @@ def factorize( >>> uniques Index(['a', 'b', 'c', None], dtype='object') - >>> codes, uniques = psser.factorize(na_sentinel=-2) - >>> codes - 0 1 - 1 -2 - 2 0 - 3 2 - 4 1 - dtype: int32 - >>> uniques - Index(['a', 'b', 'c'], dtype='object') - For Index: >>> psidx = ps.Index(['b', None, 'a', 'c', 'b']) >>> codes, uniques = psidx.factorize() - >>> codes # doctest: +SKIP - Int64Index([1, -1, 0, 2, 1], dtype='int64') + >>> codes + Index([1, -1, 0, 2, 1], dtype='int32') >>> uniques Index(['a', 'b', 'c'], dtype='object') """ from pyspark.pandas.series import first_series - assert (na_sentinel is None) or isinstance(na_sentinel, int) assert sort is True + use_na_sentinel = -1 if use_na_sentinel else False # type: ignore[assignment] warnings.warn( "Argument `na_sentinel` will be removed in 4.0.0.", @@ -1716,7 +1705,7 @@ def factorize( scol = map_scol[self.spark.column] codes, uniques = self._with_new_scol( scol.alias(self._internal.data_spark_column_names[0]) - ).factorize(na_sentinel=na_sentinel) + ).factorize(use_na_sentinel=use_na_sentinel) return codes, uniques.astype(self.dtype) uniq_sdf = self._internal.spark_frame.select(self.spark.column).distinct() @@ -1743,13 +1732,13 @@ def factorize( # Constructs `unique_to_code` mapping non-na unique to code unique_to_code = {} - if na_sentinel is not None: - na_sentinel_code = na_sentinel + if use_na_sentinel: + na_sentinel_code = use_na_sentinel code = 0 for unique in uniques_list: if pd.isna(unique): - if na_sentinel is None: - na_sentinel_code = code + if not use_na_sentinel: + na_sentinel_code = code # type: ignore[assignment] else: unique_to_code[unique] = code code += 1 @@ -1767,7 +1756,7 @@ def factorize( codes = self._with_new_scol(new_scol.alias(self._internal.data_spark_column_names[0])) - if na_sentinel is not None: + if use_na_sentinel: # Drops the NaN from the uniques of the values uniques_list = [x for x in uniques_list if not pd.isna(x)] diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index d8a3f812c33ab..72d4a88b69203 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -734,8 +734,8 @@ def axes(self) -> List: -------- >>> df = ps.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) - >>> df.axes # doctest: +SKIP - [Int64Index([0, 1], dtype='int64'), Index(['col1', 'col2'], dtype='object')] + >>> df.axes + [Index([0, 1], dtype='int64'), Index(['col1', 'col2'], dtype='object')] """ return [self.index, self.columns] @@ -1880,11 +1880,9 @@ def items(self) -> Iterator[Tuple[Name, "Series"]]: polar bear 22000 koala marsupial 80000 - >>> for label, content in df.iteritems(): + >>> for label, content in df.items(): ... print('label:', label) ... print('content:', content.to_string()) - ... - ... # doctest: +SKIP label: species content: panda bear polar bear @@ -2057,20 +2055,6 @@ def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]: ): yield tuple(([k] if index else []) + list(v)) - def iteritems(self) -> Iterator[Tuple[Name, "Series"]]: - """ - This is an alias of ``items``. - - .. deprecated:: 3.4.0 - iteritems is deprecated and will be removed in a future version. - Use .items instead. - """ - warnings.warn( - "Deprecated in 3.4.0, and will be removed in 4.0.0. 
Use DataFrame.items instead.", - FutureWarning, - ) - return self.items() - def to_clipboard(self, excel: bool = True, sep: Optional[str] = None, **kwargs: Any) -> None: """ Copy object to the system clipboard. @@ -8723,8 +8707,8 @@ def join( the original DataFrame’s index in the result unlike pandas. >>> join_psdf = psdf1.join(psdf2.set_index('key'), on='key') - >>> join_psdf.index # doctest: +SKIP - Int64Index([0, 1, 2, 3], dtype='int64') + >>> join_psdf.index + Index([0, 1, 2, 3], dtype='int64') """ if isinstance(right, ps.Series): common = list(self.columns.intersection([right.name])) @@ -8837,91 +8821,6 @@ def combine_first(self, other: "DataFrame") -> "DataFrame": ) return DataFrame(internal) - def append( - self, - other: "DataFrame", - ignore_index: bool = False, - verify_integrity: bool = False, - sort: bool = False, - ) -> "DataFrame": - """ - Append rows of other to the end of caller, returning a new object. - - Columns in other that are not in the caller are added as new columns. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - other : DataFrame or Series/dict-like object, or list of these - The data to append. - - ignore_index : boolean, default False - If True, do not use the index labels. - - verify_integrity : boolean, default False - If True, raise ValueError on creating index with duplicates. - - sort : boolean, default False - Currently not supported. - - Returns - ------- - appended : DataFrame - - Examples - -------- - >>> df = ps.DataFrame([[1, 2], [3, 4]], columns=list('AB')) - - >>> df.append(df) - A B - 0 1 2 - 1 3 4 - 0 1 2 - 1 3 4 - - >>> df.append(df, ignore_index=True) - A B - 0 1 2 - 1 3 4 - 2 1 2 - 3 3 4 - """ - warnings.warn( - "The DataFrame.append method is deprecated " - "and will be removed in 4.0.0. " - "Use pyspark.pandas.concat instead.", - FutureWarning, - ) - if isinstance(other, ps.Series): - raise TypeError("DataFrames.append() does not support appending Series to DataFrames") - if sort: - raise NotImplementedError("The 'sort' parameter is currently not supported") - - if not ignore_index: - index_scols = self._internal.index_spark_columns - if len(index_scols) != other._internal.index_level: - raise ValueError("Both DataFrames have to have the same number of index levels") - - if ( - verify_integrity - and len(index_scols) > 0 - and ( - self._internal.spark_frame.select(index_scols) - .intersect( - other._internal.spark_frame.select(other._internal.index_spark_columns) - ) - .count() - ) - > 0 - ): - raise ValueError("Indices have overlapping values") - - # Lazy import to avoid circular dependency issues - from pyspark.pandas.namespace import concat - - return cast(DataFrame, concat([self, other], ignore_index=ignore_index)) - # TODO: add 'filter_func' and 'errors' parameter def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True) -> None: """ @@ -12719,107 +12618,6 @@ def explode(self, column: Name, ignore_index: bool = False) -> "DataFrame": result_df: DataFrame = DataFrame(internal) return result_df.reset_index(drop=True) if ignore_index else result_df - def mad(self, axis: Axis = 0) -> "Series": - """ - Return the mean absolute deviation of values. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - axis : {index (0), columns (1)} - Axis for the function to be applied on. - - Examples - -------- - >>> df = ps.DataFrame({'a': [1, 2, 3, np.nan], 'b': [0.1, 0.2, 0.3, np.nan]}, - ... 
columns=['a', 'b']) - - >>> df.mad() - a 0.666667 - b 0.066667 - dtype: float64 - - >>> df.mad(axis=1) # doctest: +SKIP - 0 0.45 - 1 0.90 - 2 1.35 - 3 NaN - dtype: float64 - """ - warnings.warn( - "The 'mad' method is deprecated and will be removed in 4.0.0. " - "To compute the same result, you may do `(df - df.mean()).abs().mean()`.", - FutureWarning, - ) - from pyspark.pandas.series import first_series - - axis = validate_axis(axis) - - if axis == 0: - - def get_spark_column(psdf: DataFrame, label: Label) -> PySparkColumn: - scol = psdf._internal.spark_column_for(label) - col_type = psdf._internal.spark_type_for(label) - - if isinstance(col_type, BooleanType): - scol = scol.cast("integer") - - return scol - - new_column_labels: List[Label] = [] - for label in self._internal.column_labels: - # Filtering out only columns of numeric and boolean type column. - dtype = self._psser_for(label).spark.data_type - if isinstance(dtype, (NumericType, BooleanType)): - new_column_labels.append(label) - - new_columns = [ - F.avg(get_spark_column(self, label)).alias(name_like_string(label)) - for label in new_column_labels - ] - - mean_data = self._internal.spark_frame.select(*new_columns).first() - - new_columns = [ - F.avg( - F.abs(get_spark_column(self, label) - mean_data[name_like_string(label)]) - ).alias(name_like_string(label)) - for label in new_column_labels - ] - - sdf = self._internal.spark_frame.select( - *[F.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)], *new_columns - ) - - # The data is expected to be small so it's fine to transpose/use the default index. - with ps.option_context("compute.max_rows", 1): - internal = InternalFrame( - spark_frame=sdf, - index_spark_columns=[scol_for(sdf, SPARK_DEFAULT_INDEX_NAME)], - column_labels=new_column_labels, - column_label_names=self._internal.column_label_names, - ) - return first_series(DataFrame(internal).transpose()) - - else: - - @pandas_udf(returnType=DoubleType()) # type: ignore[call-overload] - def calculate_columns_axis(*cols: pd.Series) -> pd.Series: - return pd.concat(cols, axis=1).mad(axis=1) - - internal = self._internal.copy( - column_labels=[None], - data_spark_columns=[ - calculate_columns_axis(*self._internal.data_spark_columns).alias( - SPARK_DEFAULT_SERIES_NAME - ) - ], - data_fields=[None], - column_label_names=None, - ) - return first_series(DataFrame(internal)) - def mode(self, axis: Axis = 0, numeric_only: bool = False, dropna: bool = True) -> "DataFrame": """ Get the mode(s) of each element along the selected axis. diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 663a635668ebf..2de328177937f 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -991,87 +991,6 @@ def skew(self) -> FrameLike: bool_to_numeric=True, ) - # TODO: 'axis', 'skipna', 'level' parameter should be implemented. - def mad(self) -> FrameLike: - """ - Compute mean absolute deviation of groups, excluding missing values. - - .. versionadded:: 3.4.0 - - .. deprecated:: 3.4.0 - - Examples - -------- - >>> df = ps.DataFrame({"A": [1, 2, 1, 1], "B": [True, False, False, True], - ... 
"C": [3, 4, 3, 4], "D": ["a", "b", "b", "a"]}) - - >>> df.groupby("A").mad() - B C - A - 1 0.444444 0.444444 - 2 0.000000 0.000000 - - >>> df.B.groupby(df.A).mad() - A - 1 0.444444 - 2 0.000000 - Name: B, dtype: float64 - - See Also - -------- - pyspark.pandas.Series.groupby - pyspark.pandas.DataFrame.groupby - """ - warnings.warn( - "The 'mad' method is deprecated and will be removed in a future version. " - "To compute the same result, you may do `(group_df - group_df.mean()).abs().mean()`.", - FutureWarning, - ) - groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))] - internal, agg_columns, sdf = self._prepare_reduce( - groupkey_names=groupkey_names, - accepted_spark_types=(NumericType, BooleanType), - bool_to_numeric=False, - ) - psdf: DataFrame = DataFrame(internal) - - if len(psdf._internal.column_labels) > 0: - window = Window.partitionBy(groupkey_names).rowsBetween( - Window.unboundedPreceding, Window.unboundedFollowing - ) - new_agg_scols = {} - new_stat_scols = [] - for agg_column in agg_columns: - # it is not able to directly use 'self._reduce_for_stat_function', due to - # 'it is not allowed to use a window function inside an aggregate function'. - # so we need to create temporary columns to compute the 'abs(x - avg(x))' here. - agg_column_name = agg_column._internal.data_spark_column_names[0] - new_agg_column_name = verify_temp_column_name( - psdf._internal.spark_frame, "__tmp_agg_col_{}__".format(agg_column_name) - ) - casted_agg_scol = F.col(agg_column_name).cast("double") - new_agg_scols[new_agg_column_name] = F.abs( - casted_agg_scol - F.avg(casted_agg_scol).over(window) - ) - new_stat_scols.append(F.avg(F.col(new_agg_column_name)).alias(agg_column_name)) - - sdf = ( - psdf._internal.spark_frame.withColumns(new_agg_scols) - .groupby(groupkey_names) - .agg(*new_stat_scols) - ) - else: - sdf = sdf.select(*groupkey_names).distinct() - - internal = internal.copy( - spark_frame=sdf, - index_spark_columns=[scol_for(sdf, col) for col in groupkey_names], - data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names], - data_fields=None, - ) - - return self._prepare_return(DataFrame(internal)) - def sem(self, ddof: int = 1) -> FrameLike: """ Compute standard error of the mean of groups, excluding missing values. diff --git a/python/pyspark/pandas/indexes/__init__.py b/python/pyspark/pandas/indexes/__init__.py index 7fde6ffaf61da..0193d366024cd 100644 --- a/python/pyspark/pandas/indexes/__init__.py +++ b/python/pyspark/pandas/indexes/__init__.py @@ -17,5 +17,4 @@ from pyspark.pandas.indexes.base import Index # noqa: F401 from pyspark.pandas.indexes.datetimes import DatetimeIndex # noqa: F401 from pyspark.pandas.indexes.multi import MultiIndex # noqa: F401 -from pyspark.pandas.indexes.numeric import Float64Index, Int64Index # noqa: F401 from pyspark.pandas.indexes.timedelta import TimedeltaIndex # noqa: F401 diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index a8fd07aa2a73d..4c2ab13743592 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -51,7 +51,6 @@ from pyspark.sql import functions as F from pyspark.sql.types import ( DayTimeIntervalType, - FractionalType, IntegralType, TimestampType, TimestampNTZType, @@ -112,19 +111,17 @@ class Index(IndexOpsMixin): -------- MultiIndex : A multi-level, or hierarchical, Index. DatetimeIndex : Index of datetime64 data. - Int64Index : A special case of :class:`Index` with purely integer labels. 
- Float64Index : A special case of :class:`Index` with purely float labels. Examples -------- - >>> ps.DataFrame({'a': ['a', 'b', 'c']}, index=[1, 2, 3]).index # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> ps.DataFrame({'a': ['a', 'b', 'c']}, index=[1, 2, 3]).index + Index([1, 2, 3], dtype='int64') - >>> ps.DataFrame({'a': [1, 2, 3]}, index=list('abc')).index # doctest: +SKIP + >>> ps.DataFrame({'a': [1, 2, 3]}, index=list('abc')).index Index(['a', 'b', 'c'], dtype='object') - >>> ps.Index([1, 2, 3]) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> ps.Index([1, 2, 3]) + Index([1, 2, 3], dtype='int64') >>> ps.Index(list('abc')) Index(['a', 'b', 'c'], dtype='object') @@ -132,14 +129,14 @@ class Index(IndexOpsMixin): From a Series: >>> s = ps.Series([1, 2, 3], index=[10, 20, 30]) - >>> ps.Index(s) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> ps.Index(s) + Index([1, 2, 3], dtype='int64') From an Index: >>> idx = ps.Index([1, 2, 3]) - >>> ps.Index(idx) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> ps.Index(idx) + Index([1, 2, 3], dtype='int64') """ def __new__( @@ -198,7 +195,6 @@ def _new_instance(anchor: DataFrame) -> "Index": from pyspark.pandas.indexes.category import CategoricalIndex from pyspark.pandas.indexes.datetimes import DatetimeIndex from pyspark.pandas.indexes.multi import MultiIndex - from pyspark.pandas.indexes.numeric import Float64Index, Int64Index from pyspark.pandas.indexes.timedelta import TimedeltaIndex instance: Index @@ -206,14 +202,6 @@ def _new_instance(anchor: DataFrame) -> "Index": instance = object.__new__(MultiIndex) elif isinstance(anchor._internal.index_fields[0].dtype, CategoricalDtype): instance = object.__new__(CategoricalIndex) - elif isinstance( - anchor._internal.spark_type_for(anchor._internal.index_spark_columns[0]), IntegralType - ): - instance = object.__new__(Int64Index) - elif isinstance( - anchor._internal.spark_type_for(anchor._internal.index_spark_columns[0]), FractionalType - ): - instance = object.__new__(Float64Index) elif isinstance( anchor._internal.spark_type_for(anchor._internal.index_spark_columns[0]), (TimestampType, TimestampNTZType), @@ -800,8 +788,8 @@ def rename(self, name: Union[Name, List[Name]], inplace: bool = False) -> Option Examples -------- >>> df = ps.DataFrame({'a': ['A', 'C'], 'b': ['A', 'B']}, columns=['a', 'b']) - >>> df.index.rename("c") # doctest: +SKIP - Int64Index([0, 1], dtype='int64', name='c') + >>> df.index.rename("c") + Index([0, 1], dtype='int64', name='c') >>> df.set_index("a", inplace=True) >>> df.index.rename("d") @@ -869,11 +857,11 @@ def fillna(self, value: Scalar) -> "Index": Examples -------- >>> idx = ps.Index([1, 2, None]) - >>> idx # doctest: +SKIP - Float64Index([1.0, 2.0, nan], dtype='float64') + >>> idx + Index([1.0, 2.0, nan], dtype='float64') - >>> idx.fillna(0) # doctest: +SKIP - Float64Index([1.0, 2.0, 0.0], dtype='float64') + >>> idx.fillna(0) + Index([1.0, 2.0, 0.0], dtype='float64') """ if not isinstance(value, (float, int, str, bool)): raise TypeError("Unsupported type %s" % type(value).__name__) @@ -1242,8 +1230,7 @@ def unique(self, level: Optional[Union[int, Name]] = None) -> "Index": Examples -------- >>> ps.DataFrame({'a': ['a', 'b', 'c']}, index=[1, 1, 3]).index.unique().sort_values() - ... 
# doctest: +SKIP - Int64Index([1, 3], dtype='int64') + Index([1, 3], dtype='int64') >>> ps.DataFrame({'a': ['a', 'b', 'c']}, index=['d', 'e', 'e']).index.unique().sort_values() Index(['d', 'e'], dtype='object') @@ -1287,11 +1274,11 @@ def drop(self, labels: List[Any]) -> "Index": Examples -------- >>> index = ps.Index([1, 2, 3]) - >>> index # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> index + Index([1, 2, 3], dtype='int64') - >>> index.drop([1]) # doctest: +SKIP - Int64Index([2, 3], dtype='int64') + >>> index.drop([1]) + Index([2, 3], dtype='int64') """ internal = self._internal.resolved_copy sdf = internal.spark_frame[~internal.index_spark_columns[0].isin(labels)] @@ -1406,8 +1393,8 @@ def droplevel(self, level: Union[int, Name, List[Union[int, Name]]]) -> "Index": MultiIndex([('a', 'b', 1), ('x', 'y', 2)], ) - >>> midx.droplevel([0, 1]) # doctest: +SKIP - Int64Index([1, 2], dtype='int64') + >>> midx.droplevel([0, 1]) + Index([1, 2], dtype='int64') >>> midx.droplevel(0) # doctest: +SKIP MultiIndex([('b', 1), ('y', 2)], @@ -1510,23 +1497,23 @@ def symmetric_difference( >>> s1 = ps.Series([1, 2, 3, 4], index=[1, 2, 3, 4]) >>> s2 = ps.Series([1, 2, 3, 4], index=[2, 3, 4, 5]) - >>> s1.index.symmetric_difference(s2.index) # doctest: +SKIP - Int64Index([5, 1], dtype='int64') + >>> s1.index.symmetric_difference(s2.index) + Index([1, 5], dtype='int64') You can set name of result Index. - >>> s1.index.symmetric_difference(s2.index, result_name='pandas-on-Spark') # doctest: +SKIP - Int64Index([5, 1], dtype='int64', name='pandas-on-Spark') + >>> s1.index.symmetric_difference(s2.index, result_name='pandas-on-Spark') + Index([1, 5], dtype='int64', name='pandas-on-Spark') You can set sort to `True`, if you want to sort the resulting index. - >>> s1.index.symmetric_difference(s2.index, sort=True) # doctest: +SKIP - Int64Index([1, 5], dtype='int64') + >>> s1.index.symmetric_difference(s2.index, sort=True) + Index([1, 5], dtype='int64') You can also use the ``^`` operator: - >>> s1.index ^ s2.index # doctest: +SKIP - Int64Index([5, 1], dtype='int64') + >>> (s1.index ^ s2.index) + Index([1, 5], dtype='int64') """ if type(self) != type(other): raise NotImplementedError( @@ -1592,23 +1579,23 @@ def sort_values( Examples -------- >>> idx = ps.Index([10, 100, 1, 1000]) - >>> idx # doctest: +SKIP - Int64Index([10, 100, 1, 1000], dtype='int64') + >>> idx + Index([10, 100, 1, 1000], dtype='int64') Sort values in ascending order (default behavior). - >>> idx.sort_values() # doctest: +SKIP - Int64Index([1, 10, 100, 1000], dtype='int64') + >>> idx.sort_values() + Index([1, 10, 100, 1000], dtype='int64') Sort values in descending order. - >>> idx.sort_values(ascending=False) # doctest: +SKIP - Int64Index([1000, 100, 10, 1], dtype='int64') + >>> idx.sort_values(ascending=False) + Index([1000, 100, 10, 1], dtype='int64') Sort values in descending order, and also get the indices idx was sorted by. - >>> idx.sort_values(ascending=False, return_indexer=True) # doctest: +SKIP - (Int64Index([1000, 100, 10, 1], dtype='int64'), Int64Index([3, 1, 0, 2], dtype='int64')) + >>> idx.sort_values(ascending=False, return_indexer=True) + (Index([1000, 100, 10, 1], dtype='int64'), Index([3, 1, 0, 2], dtype='int64')) Support for MultiIndex. 
@@ -1631,11 +1618,11 @@ def sort_values( ('a', 'x', 1)], ) - >>> psidx.sort_values(ascending=False, return_indexer=True) # doctest: +SKIP + >>> psidx.sort_values(ascending=False, return_indexer=True) (MultiIndex([('c', 'y', 2), ('b', 'z', 3), ('a', 'x', 1)], - ), Int64Index([1, 2, 0], dtype='int64')) + ), Index([1, 2, 0], dtype='int64')) """ sdf = self._internal.spark_frame if return_indexer: @@ -1772,14 +1759,14 @@ def delete(self, loc: Union[int, List[int]]) -> "Index": Examples -------- >>> psidx = ps.Index([10, 10, 9, 8, 4, 2, 4, 4, 2, 2, 10, 10]) - >>> psidx # doctest: +SKIP - Int64Index([10, 10, 9, 8, 4, 2, 4, 4, 2, 2, 10, 10], dtype='int64') + >>> psidx + Index([10, 10, 9, 8, 4, 2, 4, 4, 2, 2, 10, 10], dtype='int64') - >>> psidx.delete(0).sort_values() # doctest: +SKIP - Int64Index([2, 2, 2, 4, 4, 4, 8, 9, 10, 10, 10], dtype='int64') + >>> psidx.delete(0).sort_values() + Index([2, 2, 2, 4, 4, 4, 8, 9, 10, 10, 10], dtype='int64') - >>> psidx.delete([0, 1, 2, 3, 10, 11]).sort_values() # doctest: +SKIP - Int64Index([2, 2, 2, 4, 4, 4], dtype='int64') + >>> psidx.delete([0, 1, 2, 3, 10, 11]).sort_values() + Index([2, 2, 2, 4, 4, 4], dtype='int64') MultiIndex @@ -1888,11 +1875,11 @@ def append(self, other: "Index") -> "Index": Examples -------- >>> psidx = ps.Index([10, 5, 0, 5, 10, 5, 0, 10]) - >>> psidx # doctest: +SKIP - Int64Index([10, 5, 0, 5, 10, 5, 0, 10], dtype='int64') + >>> psidx + Index([10, 5, 0, 5, 10, 5, 0, 10], dtype='int64') - >>> psidx.append(psidx) # doctest: +SKIP - Int64Index([10, 5, 0, 5, 10, 5, 0, 10, 10, 5, 0, 5, 10, 5, 0, 10], dtype='int64') + >>> psidx.append(psidx) + Index([10, 5, 0, 5, 10, 5, 0, 10, 10, 5, 0, 5, 10, 5, 0, 10], dtype='int64') Support for MiltiIndex @@ -1962,8 +1949,8 @@ def argmax(self) -> int: Examples -------- >>> psidx = ps.Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3]) - >>> psidx # doctest: +SKIP - Int64Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3], dtype='int64') + >>> psidx + Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3], dtype='int64') >>> psidx.argmax() 4 @@ -2010,8 +1997,8 @@ def argmin(self) -> int: Examples -------- >>> psidx = ps.Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3]) - >>> psidx # doctest: +SKIP - Int64Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3], dtype='int64') + >>> psidx + Index([10, 9, 8, 7, 100, 5, 4, 3, 100, 3], dtype='int64') >>> psidx.argmin() 7 @@ -2062,11 +2049,11 @@ def set_names( Examples -------- >>> idx = ps.Index([1, 2, 3, 4]) - >>> idx # doctest: +SKIP - Int64Index([1, 2, 3, 4], dtype='int64') + >>> idx + Index([1, 2, 3, 4], dtype='int64') - >>> idx.set_names('quarter') # doctest: +SKIP - Int64Index([1, 2, 3, 4], dtype='int64', name='quarter') + >>> idx.set_names('quarter') + Index([1, 2, 3, 4], dtype='int64', name='quarter') For MultiIndex @@ -2119,8 +2106,8 @@ def difference(self, other: "Index", sort: Optional[bool] = None) -> "Index": >>> idx1 = ps.Index([2, 1, 3, 4]) >>> idx2 = ps.Index([3, 4, 5, 6]) - >>> idx1.difference(idx2, sort=True) # doctest: +SKIP - Int64Index([1, 2], dtype='int64') + >>> idx1.difference(idx2, sort=True) + Index([1, 2], dtype='int64') MultiIndex @@ -2136,7 +2123,7 @@ def difference(self, other: "Index", sort: Optional[bool] = None) -> "Index": # Check if the `self` and `other` have different index types. # 1. `self` is Index, `other` is MultiIndex # 2. 
`self` is MultiIndex, `other` is Index - is_index_types_different = isinstance(other, Index) and not isinstance(self, type(other)) + is_index_types_different = isinstance(other, Index) and (type(self) != type(other)) if is_index_types_different: if isinstance(self, MultiIndex): # In case `self` is MultiIndex and `other` is Index, @@ -2219,8 +2206,8 @@ def is_all_dates(self) -> bool: True >>> idx = ps.Index([0, 1, 2]) - >>> idx # doctest: +SKIP - Int64Index([0, 1, 2], dtype='int64') + >>> idx + Index([0, 1, 2], dtype='int64') >>> idx.is_all_dates False @@ -2403,8 +2390,8 @@ def union( >>> idx1 = ps.Index([1, 2, 3, 4]) >>> idx2 = ps.Index([3, 4, 5, 6]) - >>> idx1.union(idx2).sort_values() # doctest: +SKIP - Int64Index([1, 2, 3, 4, 5, 6], dtype='int64') + >>> idx1.union(idx2).sort_values() + Index([1, 2, 3, 4, 5, 6], dtype='int64') MultiIndex @@ -2469,8 +2456,8 @@ def holds_integer(self) -> bool: When Index contains null values the result can be different with pandas since pandas-on-Spark cast integer to float when Index contains null values. - >>> ps.Index([1, 2, 3, None]) # doctest: +SKIP - Float64Index([1.0, 2.0, 3.0, nan], dtype='float64') + >>> ps.Index([1, 2, 3, None]) + Index([1.0, 2.0, 3.0, nan], dtype='float64') Examples -------- @@ -2510,8 +2497,8 @@ def intersection(self, other: Union[DataFrame, Series, "Index", List]) -> "Index -------- >>> idx1 = ps.Index([1, 2, 3, 4]) >>> idx2 = ps.Index([3, 4, 5, 6]) - >>> idx1.intersection(idx2).sort_values() # doctest: +SKIP - Int64Index([3, 4], dtype='int64') + >>> idx1.intersection(idx2).sort_values() + Index([3, 4], dtype='int64') """ from pyspark.pandas.indexes.multi import MultiIndex @@ -2599,14 +2586,14 @@ def insert(self, loc: int, item: Any) -> "Index": Examples -------- >>> psidx = ps.Index([1, 2, 3, 4, 5]) - >>> psidx.insert(3, 100) # doctest: +SKIP - Int64Index([1, 2, 3, 100, 4, 5], dtype='int64') + >>> psidx.insert(3, 100) + Index([1, 2, 3, 100, 4, 5], dtype='int64') For negative values >>> psidx = ps.Index([1, 2, 3, 4, 5]) - >>> psidx.insert(-3, 100) # doctest: +SKIP - Int64Index([1, 2, 100, 3, 4, 5], dtype='int64') + >>> psidx.insert(-3, 100) + Index([1, 2, 100, 3, 4, 5], dtype='int64') """ validate_index_loc(self, loc) loc = loc + len(self) if loc < 0 else loc diff --git a/python/pyspark/pandas/indexes/category.py b/python/pyspark/pandas/indexes/category.py index 7bc87805e1552..94725f90679a6 100644 --- a/python/pyspark/pandas/indexes/category.py +++ b/python/pyspark/pandas/indexes/category.py @@ -141,8 +141,8 @@ def codes(self) -> Index: CategoricalIndex(['a', 'b', 'b', 'c', 'c', 'c'], categories=['a', 'b', 'c'], ordered=False, dtype='category') - >>> idx.codes # doctest: +SKIP - Int64Index([0, 1, 1, 2, 2, 2], dtype='int64') + >>> idx.codes + Index([0, 1, 1, 2, 2, 2], dtype='int8') """ return self._with_new_scol( self.spark.column, diff --git a/python/pyspark/pandas/indexes/datetimes.py b/python/pyspark/pandas/indexes/datetimes.py index 9adef61087a9e..1971d90a74272 100644 --- a/python/pyspark/pandas/indexes/datetimes.py +++ b/python/pyspark/pandas/indexes/datetimes.py @@ -261,7 +261,7 @@ def dayofweek(self) -> Index: -------- >>> idx = ps.date_range('2016-12-31', '2017-01-08', freq='D') # doctest: +SKIP >>> idx.dayofweek # doctest: +SKIP - Int64Index([5, 6, 0, 1, 2, 3, 4, 5, 6], dtype='int64') + Index([5, 6, 0, 1, 2, 3, 4, 5, 6], dtype='int64') """ warnings.warn( "`dayofweek` will return int32 index instead of int 64 index in 4.0.0.", @@ -737,13 +737,13 @@ def indexer_between_time( dtype='datetime64[ns]', freq=None) >>> 
psidx.indexer_between_time("00:01", "00:02").sort_values() # doctest: +SKIP - Int64Index([1, 2], dtype='int64') + Index([1, 2], dtype='int64') >>> psidx.indexer_between_time("00:01", "00:02", include_end=False) # doctest: +SKIP - Int64Index([1], dtype='int64') + Index([1], dtype='int64') >>> psidx.indexer_between_time("00:01", "00:02", include_start=False) # doctest: +SKIP - Int64Index([2], dtype='int64') + Index([2], dtype='int64') """ def pandas_between_time(pdf) -> ps.DataFrame[int]: # type: ignore[no-untyped-def] @@ -783,10 +783,10 @@ def indexer_at_time(self, time: Union[datetime.time, str], asof: bool = False) - dtype='datetime64[ns]', freq=None) >>> psidx.indexer_at_time("00:00") # doctest: +SKIP - Int64Index([0], dtype='int64') + Index([0], dtype='int64') >>> psidx.indexer_at_time("00:01") # doctest: +SKIP - Int64Index([1], dtype='int64') + Index([1], dtype='int64') """ if asof: raise NotImplementedError("'asof' argument is not supported") diff --git a/python/pyspark/pandas/indexes/numeric.py b/python/pyspark/pandas/indexes/numeric.py deleted file mode 100644 index d0b5bc5d15989..0000000000000 --- a/python/pyspark/pandas/indexes/numeric.py +++ /dev/null @@ -1,210 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import warnings -from typing import Any, Optional, Union, cast - -import pandas as pd -from pandas.api.types import is_hashable # type: ignore[attr-defined] - -from pyspark import pandas as ps -from pyspark.pandas._typing import Dtype, Name -from pyspark.pandas.indexes.base import Index -from pyspark.pandas.series import Series - - -class NumericIndex(Index): - """ - Provide numeric type operations. - This is an abstract class. - """ - - pass - - -class IntegerIndex(NumericIndex): - """ - This is an abstract class for Int64Index. - """ - - pass - - -class Int64Index(IntegerIndex): - """ - Immutable sequence used for indexing and alignment. The basic object - storing axis labels for all pandas objects. Int64Index is a special case - of `Index` with purely integer labels. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - data : array-like (1-dimensional) - dtype : NumPy dtype (default: int64) - copy : bool - Make a copy of input ndarray. - name : object - Name to be stored in the index. - - See Also - -------- - Index : The base pandas-on-Spark Index type. - Float64Index : A special case of :class:`Index` with purely float labels. - - Notes - ----- - An Index instance can **only** contain hashable objects. 
- - Examples - -------- - >>> ps.Int64Index([1, 2, 3]) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') - - From a Series: - - >>> s = ps.Series([1, 2, 3], index=[10, 20, 30]) - >>> ps.Int64Index(s) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') - - From an Index: - - >>> idx = ps.Index([1, 2, 3]) - >>> ps.Int64Index(idx) # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') - """ - - def __new__( - cls, - data: Optional[Any] = None, - dtype: Optional[Union[str, Dtype]] = None, - copy: bool = False, - name: Optional[Name] = None, - ) -> "Int64Index": - warnings.warn( - "Int64Index is deprecated in 3.4.0, and will be removed in 4.0.0. Use Index instead.", - FutureWarning, - ) - if not is_hashable(name): - raise TypeError("Index.name must be a hashable type") - - if isinstance(data, (Series, Index)): - if dtype is None: - dtype = "int64" - return cast(Int64Index, Index(data, dtype=dtype, copy=copy, name=name)) - - return cast( - Int64Index, ps.from_pandas(pd.Int64Index(data=data, dtype=dtype, copy=copy, name=name)) - ) - - -class Float64Index(NumericIndex): - """ - Immutable sequence used for indexing and alignment. The basic object - storing axis labels for all pandas objects. Float64Index is a special case - of `Index` with purely float labels. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - data : array-like (1-dimensional) - dtype : NumPy dtype (default: float64) - copy : bool - Make a copy of input ndarray. - name : object - Name to be stored in the index. - - See Also - -------- - Index : The base pandas-on-Spark Index type. - Int64Index : A special case of :class:`Index` with purely integer labels. - - Notes - ----- - An Index instance can **only** contain hashable objects. - - Examples - -------- - >>> ps.Float64Index([1.0, 2.0, 3.0]) # doctest: +SKIP - Float64Index([1.0, 2.0, 3.0], dtype='float64') - - From a Series: - - >>> s = ps.Series([1, 2, 3], index=[10, 20, 30]) - >>> ps.Float64Index(s) # doctest: +SKIP - Float64Index([1.0, 2.0, 3.0], dtype='float64') - - From an Index: - - >>> idx = ps.Index([1, 2, 3]) - >>> ps.Float64Index(idx) # doctest: +SKIP - Float64Index([1.0, 2.0, 3.0], dtype='float64') - """ - - def __new__( - cls, - data: Optional[Any] = None, - dtype: Optional[Union[str, Dtype]] = None, - copy: bool = False, - name: Optional[Name] = None, - ) -> "Float64Index": - warnings.warn( - "Float64Index is deprecated in 3.4.0, and will be removed in 4.0.0. 
Use Index instead.", - FutureWarning, - ) - if not is_hashable(name): - raise TypeError("Index.name must be a hashable type") - - if isinstance(data, (Series, Index)): - if dtype is None: - dtype = "float64" - return cast(Float64Index, Index(data, dtype=dtype, copy=copy, name=name)) - - return cast( - Float64Index, - ps.from_pandas(pd.Float64Index(data=data, dtype=dtype, copy=copy, name=name)), - ) - - -def _test() -> None: - import os - import doctest - import sys - from pyspark.sql import SparkSession - import pyspark.pandas.indexes.numeric - - os.chdir(os.environ["SPARK_HOME"]) - - globs = pyspark.pandas.indexes.numeric.__dict__.copy() - globs["ps"] = pyspark.pandas - spark = ( - SparkSession.builder.master("local[4]") - .appName("pyspark.pandas.indexes.numeric tests") - .getOrCreate() - ) - (failure_count, test_count) = doctest.testmod( - pyspark.pandas.indexes.numeric, - globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE, - ) - spark.stop() - if failure_count: - sys.exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index 3563a6d81b4fa..5ffec6bedb988 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -2365,7 +2365,6 @@ def concat( See Also -------- - Series.append : Concatenate Series. DataFrame.join : Join DataFrames using indexes. DataFrame.merge : Merge DataFrames by indexes or columns. diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 95ca92e78787d..a74f36986f3b5 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -492,8 +492,8 @@ def axes(self) -> List["Index"]: -------- >>> psser = ps.Series([1, 2, 3]) - >>> psser.axes # doctest: +SKIP - [Int64Index([0, 1, 2], dtype='int64')] + >>> psser.axes + [Index([0, 1, 2], dtype='int64')] """ return [self.index] @@ -3584,71 +3584,6 @@ def nlargest(self, n: int = 5) -> "Series": """ return self.sort_values(ascending=False).head(n) - def append( - self, to_append: "Series", ignore_index: bool = False, verify_integrity: bool = False - ) -> "Series": - """ - Concatenate two or more Series. - - .. deprecated:: 3.4.0 - - Parameters - ---------- - to_append : Series or list/tuple of Series - ignore_index : boolean, default False - If True, do not use the index labels. - verify_integrity : boolean, default False - If True, raise Exception on creating index with duplicates - - Returns - ------- - appended : Series - - Examples - -------- - >>> s1 = ps.Series([1, 2, 3]) - >>> s2 = ps.Series([4, 5, 6]) - >>> s3 = ps.Series([4, 5, 6], index=[3,4,5]) - - >>> s1.append(s2) # doctest: +SKIP - 0 1 - 1 2 - 2 3 - 0 4 - 1 5 - 2 6 - dtype: int64 - - >>> s1.append(s3) # doctest: +SKIP - 0 1 - 1 2 - 2 3 - 3 4 - 4 5 - 5 6 - dtype: int64 - - With ignore_index set to True: - - >>> s1.append(s2, ignore_index=True) # doctest: +SKIP - 0 1 - 1 2 - 2 3 - 3 4 - 4 5 - 5 6 - dtype: int64 - """ - warnings.warn( - "The Series.append method is deprecated " - "and will be removed in 4.0.0. 
" - "Use pyspark.pandas.concat instead.", - FutureWarning, - ) - return first_series( - self.to_frame().append(to_append.to_frame(), ignore_index, verify_integrity) - ).rename(self.name) - def sample( self, n: Optional[int] = None, @@ -5939,37 +5874,6 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: pdf.columns = pd.Index(where) return first_series(DataFrame(pdf.transpose())).rename(self.name) - def mad(self) -> float: - """ - Return the mean absolute deviation of values. - - .. deprecated:: 3.4.0 - - Examples - -------- - >>> s = ps.Series([1, 2, 3, 4]) - >>> s - 0 1 - 1 2 - 2 3 - 3 4 - dtype: int64 - - >>> s.mad() - 1.0 - """ - warnings.warn( - "The 'mad' method is deprecated and will be removed in 4.0.0. " - "To compute the same result, you may do `(series - series.mean()).abs().mean()`.", - FutureWarning, - ) - sdf = self._internal.spark_frame - spark_column = self.spark.column - avg = unpack_scalar(sdf.select(F.avg(spark_column))) - mad = unpack_scalar(sdf.select(F.avg(F.abs(spark_column - avg)))) - - return mad - def unstack(self, level: int = -1) -> DataFrame: """ Unstack, a.k.a. pivot, Series with MultiIndex to produce DataFrame. @@ -6083,7 +5987,7 @@ def items(self) -> Iterable[Tuple[Name, Any]]: This method returns an iterable tuple (index, value). This is convenient if you want to create a lazy iterator. - .. note:: Unlike pandas', the iteritems in pandas-on-Spark returns generator rather + .. note:: Unlike pandas', the itmes in pandas-on-Spark returns generator rather zip object Returns @@ -6123,20 +6027,6 @@ def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]: ): yield k, v - def iteritems(self) -> Iterable[Tuple[Name, Any]]: - """ - This is an alias of ``items``. - - .. deprecated:: 3.4.0 - iteritems is deprecated and will be removed in a future version. - Use .items instead. - """ - warnings.warn( - "Deprecated in 3.4, and will be removed in 4.0.0. Use Series.items instead.", - FutureWarning, - ) - return self.items() - def droplevel(self, level: Union[int, Name, List[Union[int, Name]]]) -> "Series": """ Return Series with requested index level(s) removed. diff --git a/python/pyspark/pandas/spark/accessors.py b/python/pyspark/pandas/spark/accessors.py index f55f70e00924b..bcbe044185a75 100644 --- a/python/pyspark/pandas/spark/accessors.py +++ b/python/pyspark/pandas/spark/accessors.py @@ -105,8 +105,8 @@ def transform(self, func: Callable[[PySparkColumn], PySparkColumn]) -> IndexOpsL 2 1.098612 Name: a, dtype: float64 - >>> df.index.spark.transform(lambda c: c + 10) # doctest: +SKIP - Int64Index([10, 11, 12], dtype='int64') + >>> df.index.spark.transform(lambda c: c + 10) + Index([10, 11, 12], dtype='int64') >>> df.a.spark.transform(lambda c: c + df.b.spark.column) 0 5 @@ -283,13 +283,13 @@ def analyzed(self) -> "ps.Index": -------- >>> import pyspark.pandas as ps >>> idx = ps.Index([1, 2, 3]) - >>> idx # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> idx + Index([1, 2, 3], dtype='int64') The analyzed one should return the same value. - >>> idx.spark.analyzed # doctest: +SKIP - Int64Index([1, 2, 3], dtype='int64') + >>> idx.spark.analyzed + Index([1, 2, 3], dtype='int64') However, it won't work with the same anchor Index. @@ -299,8 +299,8 @@ def analyzed(self) -> "ps.Index": ValueError: ... enable 'compute.ops_on_diff_frames' option. >>> with ps.option_context('compute.ops_on_diff_frames', True): - ... (idx + idx.spark.analyzed).sort_values() # doctest: +SKIP - Int64Index([2, 4, 6], dtype='int64') + ... 
(idx + idx.spark.analyzed).sort_values() + Index([2, 4, 6], dtype='int64') """ from pyspark.pandas.frame import DataFrame diff --git a/python/pyspark/pandas/tests/computation/test_combine.py b/python/pyspark/pandas/tests/computation/test_combine.py index dd55c0fd68661..adba20b5d99b3 100644 --- a/python/pyspark/pandas/tests/computation/test_combine.py +++ b/python/pyspark/pandas/tests/computation/test_combine.py @@ -41,46 +41,26 @@ def df_pair(self): psdf = ps.from_pandas(pdf) return pdf, psdf - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43562): Enable DataFrameTests.test_append for pandas 2.0.0.", - ) - def test_append(self): + def test_concat(self): pdf = pd.DataFrame([[1, 2], [3, 4]], columns=list("AB")) psdf = ps.from_pandas(pdf) other_pdf = pd.DataFrame([[3, 4], [5, 6]], columns=list("BC"), index=[2, 3]) other_psdf = ps.from_pandas(other_pdf) - self.assert_eq(psdf.append(psdf), pdf.append(pdf)) - self.assert_eq(psdf.append(psdf, ignore_index=True), pdf.append(pdf, ignore_index=True)) + self.assert_eq(ps.concat([psdf, psdf]), pd.concat([pdf, pdf])) + self.assert_eq( + ps.concat([psdf, psdf], ignore_index=True), pd.concat([pdf, pdf], ignore_index=True) + ) # Assert DataFrames with non-matching columns - self.assert_eq(psdf.append(other_psdf), pdf.append(other_pdf)) - - # Assert appending a Series fails - msg = "DataFrames.append() does not support appending Series to DataFrames" - with self.assertRaises(TypeError, msg=msg): - psdf.append(psdf["A"]) - - # Assert using the sort parameter raises an exception - msg = "The 'sort' parameter is currently not supported" - with self.assertRaises(NotImplementedError, msg=msg): - psdf.append(psdf, sort=True) + self.assert_eq(ps.concat([psdf, other_psdf]), pd.concat([pdf, other_pdf])) - # Assert using 'verify_integrity' only raises an exception for overlapping indices - self.assert_eq( - psdf.append(other_psdf, verify_integrity=True), - pdf.append(other_pdf, verify_integrity=True), - ) - msg = "Indices have overlapping values" - with self.assertRaises(ValueError, msg=msg): - psdf.append(psdf, verify_integrity=True) + ps.concat([psdf, psdf["A"]]) + # Assert appending a Series + self.assert_eq(ps.concat([psdf, psdf["A"]]), pd.concat([pdf, pdf["A"]])) - # Skip integrity verification when ignore_index=True - self.assert_eq( - psdf.append(psdf, ignore_index=True, verify_integrity=True), - pdf.append(pdf, ignore_index=True, verify_integrity=True), - ) + # Assert using the sort parameter + self.assert_eq(ps.concat([psdf, psdf], sort=True), pd.concat([pdf, pdf], sort=True)) # Assert appending multi-index DataFrames multi_index_pdf = pd.DataFrame([[1, 2], [3, 4]], columns=list("AB"), index=[[2, 3], [4, 5]]) @@ -91,45 +71,32 @@ def test_append(self): other_multi_index_psdf = ps.from_pandas(other_multi_index_pdf) self.assert_eq( - multi_index_psdf.append(multi_index_psdf), multi_index_pdf.append(multi_index_pdf) + ps.concat([multi_index_psdf, multi_index_psdf]), + pd.concat([multi_index_pdf, multi_index_pdf]), ) # Assert DataFrames with non-matching columns self.assert_eq( - multi_index_psdf.append(other_multi_index_psdf), - multi_index_pdf.append(other_multi_index_pdf), - ) - - # Assert using 'verify_integrity' only raises an exception for overlapping indices - self.assert_eq( - multi_index_psdf.append(other_multi_index_psdf, verify_integrity=True), - multi_index_pdf.append(other_multi_index_pdf, verify_integrity=True), - ) - with self.assertRaises(ValueError, msg=msg): - multi_index_psdf.append(multi_index_psdf, 
verify_integrity=True) - - # Skip integrity verification when ignore_index=True - self.assert_eq( - multi_index_psdf.append(multi_index_psdf, ignore_index=True, verify_integrity=True), - multi_index_pdf.append(multi_index_pdf, ignore_index=True, verify_integrity=True), + ps.concat([multi_index_psdf, other_multi_index_psdf]), + pd.concat([multi_index_pdf, other_multi_index_pdf]), ) # Assert trying to append DataFrames with different index levels msg = "Both DataFrames have to have the same number of index levels" with self.assertRaises(ValueError, msg=msg): - psdf.append(multi_index_psdf) + ps.concat([psdf, multi_index_psdf]) # Skip index level check when ignore_index=True self.assert_eq( - psdf.append(multi_index_psdf, ignore_index=True), - pdf.append(multi_index_pdf, ignore_index=True), + ps.concat([psdf, other_multi_index_psdf], ignore_index=True), + pd.concat([pdf, other_multi_index_pdf], ignore_index=True), ) columns = pd.MultiIndex.from_tuples([("A", "X"), ("A", "Y")]) pdf.columns = columns psdf.columns = columns - self.assert_eq(psdf.append(psdf), pdf.append(pdf)) + self.assert_eq(ps.concat([psdf, psdf]), pd.concat([pdf, pdf])) def test_merge(self): left_pdf = pd.DataFrame( diff --git a/python/pyspark/pandas/tests/computation/test_compute.py b/python/pyspark/pandas/tests/computation/test_compute.py index 5ce273c1f4769..d4b49f2ac8b01 100644 --- a/python/pyspark/pandas/tests/computation/test_compute.py +++ b/python/pyspark/pandas/tests/computation/test_compute.py @@ -78,40 +78,6 @@ def test_clip(self): str_psdf = ps.DataFrame({"A": ["a", "b", "c"]}, index=np.random.rand(3)) self.assert_eq(str_psdf.clip(1, 3), str_psdf) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43560): Enable DataFrameSlowTests.test_mad for pandas 2.0.0.", - ) - def test_mad(self): - pdf = pd.DataFrame( - { - "A": [1, 2, None, 4, np.nan], - "B": [-0.1, 0.2, -0.3, np.nan, 0.5], - "C": ["a", "b", "c", "d", "e"], - } - ) - psdf = ps.from_pandas(pdf) - - self.assert_eq(psdf.mad(), pdf.mad()) - self.assert_eq(psdf.mad(axis=1), pdf.mad(axis=1)) - - with self.assertRaises(ValueError): - psdf.mad(axis=2) - - # MultiIndex columns - columns = pd.MultiIndex.from_tuples([("A", "X"), ("A", "Y"), ("A", "Z")]) - pdf.columns = columns - psdf.columns = columns - - self.assert_eq(psdf.mad(), pdf.mad()) - self.assert_eq(psdf.mad(axis=1), pdf.mad(axis=1)) - - pdf = pd.DataFrame({"A": [True, True, False, False], "B": [True, False, False, True]}) - psdf = ps.from_pandas(pdf) - - self.assert_eq(psdf.mad(), pdf.mad()) - self.assert_eq(psdf.mad(axis=1), pdf.mad(axis=1)) - def test_mode(self): pdf = pd.DataFrame( { diff --git a/python/pyspark/pandas/tests/computation/test_describe.py b/python/pyspark/pandas/tests/computation/test_describe.py index af98d2869da9b..bbee9654eae4b 100644 --- a/python/pyspark/pandas/tests/computation/test_describe.py +++ b/python/pyspark/pandas/tests/computation/test_describe.py @@ -39,10 +39,6 @@ def df_pair(self): psdf = ps.from_pandas(pdf) return pdf, psdf - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43556): Enable DataFrameSlowTests.test_describe for pandas 2.0.0.", - ) def test_describe(self): pdf, psdf = self.df_pair @@ -78,19 +74,10 @@ def test_describe(self): } ) pdf = psdf._to_pandas() - # NOTE: Set `datetime_is_numeric=True` for pandas: - # FutureWarning: Treating datetime data as categorical rather than numeric in - # `.describe` is deprecated and will be removed in a future version of pandas. 
- # Specify `datetime_is_numeric=True` to silence this - # warning and adopt the future behavior now. - # NOTE: Compare the result except percentiles, since we use approximate percentile - # so the result is different from pandas. if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], - pdf.describe(datetime_is_numeric=True) - .astype(str) - .loc[["count", "mean", "min", "max"]], + pdf.describe().astype(str).loc[["count", "mean", "min", "max"]], ) else: self.assert_eq( @@ -136,17 +123,13 @@ def test_describe(self): if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], - pdf.describe(datetime_is_numeric=True) - .astype(str) - .loc[["count", "mean", "min", "max"]], + pdf.describe().astype(str).loc[["count", "mean", "min", "max"]], ) psdf.A += psdf.A pdf.A += pdf.A self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], - pdf.describe(datetime_is_numeric=True) - .astype(str) - .loc[["count", "mean", "min", "max"]], + pdf.describe().astype(str).loc[["count", "mean", "min", "max"]], ) else: expected_result = ps.DataFrame( @@ -187,7 +170,7 @@ def test_describe(self): ) pdf = psdf._to_pandas() if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pandas_result = pdf.describe(datetime_is_numeric=True) + pandas_result = pdf.describe() pandas_result.B = pandas_result.B.astype(str) self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], @@ -195,7 +178,7 @@ def test_describe(self): ) psdf.A += psdf.A pdf.A += pdf.A - pandas_result = pdf.describe(datetime_is_numeric=True) + pandas_result = pdf.describe() pandas_result.B = pandas_result.B.astype(str) self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], @@ -252,7 +235,7 @@ def test_describe(self): ) pdf = psdf._to_pandas() if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pandas_result = pdf.describe(datetime_is_numeric=True) + pandas_result = pdf.describe() pandas_result.b = pandas_result.b.astype(str) self.assert_eq( psdf.describe().loc[["count", "mean", "min", "max"]], @@ -288,10 +271,6 @@ def test_describe(self): with self.assertRaisesRegex(ValueError, msg): psdf.describe() - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43556): Enable DataFrameSlowTests.test_describe for pandas 2.0.0.", - ) def test_describe_empty(self): # Empty DataFrame psdf = ps.DataFrame(columns=["A", "B"]) @@ -328,7 +307,7 @@ def test_describe_empty(self): # For timestamp type, we should convert NaT to None in pandas result # since pandas API on Spark doesn't support the NaT for object type. 
if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pdf_result = pdf[pdf.a != pdf.a].describe(datetime_is_numeric=True) + pdf_result = pdf[pdf.a != pdf.a].describe() self.assert_eq( psdf[psdf.a != psdf.a].describe(), pdf_result.where(pdf_result.notnull(), None).astype(str), @@ -367,7 +346,7 @@ def test_describe_empty(self): ) pdf = psdf._to_pandas() if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pdf_result = pdf[pdf.a != pdf.a].describe(datetime_is_numeric=True) + pdf_result = pdf[pdf.a != pdf.a].describe() pdf_result.b = pdf_result.b.where(pdf_result.b.notnull(), None).astype(str) self.assert_eq( psdf[psdf.a != psdf.a].describe(), @@ -417,7 +396,7 @@ def test_describe_empty(self): ) pdf = psdf._to_pandas() if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): - pdf_result = pdf[pdf.a != pdf.a].describe(datetime_is_numeric=True) + pdf_result = pdf[pdf.a != pdf.a].describe() self.assert_eq( psdf[psdf.a != psdf.a].describe(), pdf_result.where(pdf_result.notnull(), None).astype(str), diff --git a/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py b/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py index d2c4f9ae60717..c8ec48eb06aa4 100644 --- a/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +++ b/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py @@ -16,28 +16,13 @@ # import unittest -from pyspark import pandas as ps from pyspark.pandas.tests.computation.test_pivot import FramePivotMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils class FrameParityPivotTests(FramePivotMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @property - def psdf(self): - return ps.from_pandas(self.pdf) - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_pivot_table(self): - super().test_pivot_table() - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_pivot_table_dtypes(self): - super().test_pivot_table_dtypes() + pass if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py b/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py index 98ebf3ca44a07..e4bac7b078e66 100644 --- a/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +++ b/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py @@ -16,22 +16,13 @@ # import unittest -from pyspark import pandas as ps from pyspark.pandas.tests.frame.test_reshaping import FrameReshapingMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils class FrameParityReshapingTests(FrameReshapingMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @property - def psdf(self): - return ps.from_pandas(self.pdf) - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." 
- ) - def test_transpose(self): - super().test_transpose() + pass if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py index f757d19ca6941..31916f12b4e7f 100644 --- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +++ b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py @@ -22,11 +22,11 @@ class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_unstack(self): - super().test_unstack() + pass + + @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.") + def test_factorize(self): + super().test_factorize() if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/test_parity_categorical.py b/python/pyspark/pandas/tests/connect/test_parity_categorical.py index 3e05eb2c0f3b7..210cfce8ddbaf 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_categorical.py +++ b/python/pyspark/pandas/tests/connect/test_parity_categorical.py @@ -53,12 +53,6 @@ def test_reorder_categories(self): def test_set_categories(self): super().test_set_categories() - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_unstack(self): - super().test_unstack() - if __name__ == "__main__": from pyspark.pandas.tests.connect.test_parity_categorical import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/groupby/test_stat.py b/python/pyspark/pandas/tests/groupby/test_stat.py index bfdeeecce303c..8a5096942e689 100644 --- a/python/pyspark/pandas/tests/groupby/test_stat.py +++ b/python/pyspark/pandas/tests/groupby/test_stat.py @@ -206,13 +206,6 @@ def test_sum(self): psdf.groupby("A").sum(min_count=3).sort_index(), ) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43553): Enable GroupByTests.test_mad for pandas 2.0.0.", - ) - def test_mad(self): - self._test_stat_func(lambda groupby_obj: groupby_obj.mad()) - def test_first(self): self._test_stat_func(lambda groupby_obj: groupby_obj.first()) self._test_stat_func(lambda groupby_obj: groupby_obj.first(numeric_only=None)) diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 6cb7c58197f3c..736c88db4a8f5 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -42,10 +42,6 @@ def pdf(self): index=[0, 1, 3, 5, 6, 8, 9, 9, 9], ) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43606): Enable IndexesTests.test_index_basic for pandas 2.0.0.", - ) def test_index_basic(self): for pdf in [ pd.DataFrame(np.random.randn(10, 5), index=np.random.randint(100, size=10)), @@ -70,22 +66,12 @@ def test_index_basic(self): self.assert_eq(type(psdf.index).__name__, type(pdf.index).__name__) self.assert_eq(ps.Index([])._summary(), "Index: 0 entries") - if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"): - with self.assertRaisesRegexp(ValueError, "The truth value of a Index is ambiguous."): - bool(ps.Index([1])) - with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): - ps.Index([1, 2, 3], name=[(1, 2, 3)]) - with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): - ps.Index([1.0, 2.0, 3.0], name=[(1, 2, 3)]) 
- else: - with self.assertRaisesRegexp( - ValueError, "The truth value of a Int64Index is ambiguous." - ): - bool(ps.Index([1])) - with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): - ps.Int64Index([1, 2, 3], name=[(1, 2, 3)]) - with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): - ps.Float64Index([1.0, 2.0, 3.0], name=[(1, 2, 3)]) + with self.assertRaisesRegexp(ValueError, "The truth value of a Index is ambiguous."): + bool(ps.Index([1])) + with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): + ps.Index([1, 2, 3], name=[(1, 2, 3)]) + with self.assertRaisesRegexp(TypeError, "Index.name must be a hashable type"): + ps.Index([1.0, 2.0, 3.0], name=[(1, 2, 3)]) def test_index_from_series(self): pser = pd.Series([1, 2, 3], name="a", index=[10, 20, 30]) @@ -95,15 +81,8 @@ def test_index_from_series(self): self.assert_eq(ps.Index(psser, dtype="float"), pd.Index(pser, dtype="float")) self.assert_eq(ps.Index(psser, name="x"), pd.Index(pser, name="x")) - if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"): - self.assert_eq(ps.Index(psser, dtype="int64"), pd.Index(pser, dtype="int64")) - self.assert_eq(ps.Index(psser, dtype="float64"), pd.Index(pser, dtype="float64")) - elif LooseVersion(pd.__version__) >= LooseVersion("1.1"): - self.assert_eq(ps.Int64Index(psser), pd.Int64Index(pser)) - self.assert_eq(ps.Float64Index(psser), pd.Float64Index(pser)) - else: - self.assert_eq(ps.Int64Index(psser), pd.Int64Index(pser).rename("a")) - self.assert_eq(ps.Float64Index(psser), pd.Float64Index(pser).rename("a")) + self.assert_eq(ps.Index(psser, dtype="int64"), pd.Index(pser, dtype="int64")) + self.assert_eq(ps.Index(psser, dtype="float64"), pd.Index(pser, dtype="float64")) pser = pd.Series([datetime(2021, 3, 1), datetime(2021, 3, 2)], name="x", index=[10, 20]) psser = ps.from_pandas(pser) @@ -120,12 +99,8 @@ def test_index_from_index(self): self.assert_eq(ps.Index(psidx, name="x"), pd.Index(pidx, name="x")) self.assert_eq(ps.Index(psidx, copy=True), pd.Index(pidx, copy=True)) - if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"): - self.assert_eq(ps.Index(psidx, dtype="int64"), pd.Index(pidx, dtype="int64")) - self.assert_eq(ps.Index(psidx, dtype="float64"), pd.Index(pidx, dtype="float64")) - else: - self.assert_eq(ps.Int64Index(psidx), pd.Int64Index(pidx)) - self.assert_eq(ps.Float64Index(psidx), pd.Float64Index(pidx)) + self.assert_eq(ps.Index(psidx, dtype="int64"), pd.Index(pidx, dtype="int64")) + self.assert_eq(ps.Index(psidx, dtype="float64"), pd.Index(pidx, dtype="float64")) pidx = pd.DatetimeIndex(["2021-03-01", "2021-03-02"]) psidx = ps.from_pandas(pidx) diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py index ffffae828c437..6aa92b7e1e390 100644 --- a/python/pyspark/pandas/tests/indexes/test_category.py +++ b/python/pyspark/pandas/tests/indexes/test_category.py @@ -210,10 +210,6 @@ def test_astype(self): self.assert_eq(pscidx.astype(str), pcidx.astype(str)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43567): Enable CategoricalIndexTests.test_factorize for pandas 2.0.0.", - ) def test_factorize(self): pidx = pd.CategoricalIndex([1, 2, 3, None]) psidx = ps.from_pandas(pidx) @@ -224,8 +220,8 @@ def test_factorize(self): self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) - pcodes, puniques = pidx.factorize(na_sentinel=-2) - kcodes, kuniques = 
psidx.factorize(na_sentinel=-2) + pcodes, puniques = pidx.factorize(use_na_sentinel=-2) + kcodes, kuniques = psidx.factorize(use_na_sentinel=-2) self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) diff --git a/python/pyspark/pandas/tests/indexes/test_indexing.py b/python/pyspark/pandas/tests/indexes/test_indexing.py index 64fc75347baf3..111dd09696d79 100644 --- a/python/pyspark/pandas/tests/indexes/test_indexing.py +++ b/python/pyspark/pandas/tests/indexes/test_indexing.py @@ -53,11 +53,7 @@ def test_head(self): with option_context("compute.ordered_head", True): self.assert_eq(psdf.head(), pdf.head()) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43559): Enable DataFrameSlowTests.test_iteritems for pandas 2.0.0.", - ) - def test_iteritems(self): + def test_items(self): pdf = pd.DataFrame( {"species": ["bear", "bear", "marsupial"], "population": [1864, 22000, 80000]}, index=["panda", "polar", "koala"], @@ -65,7 +61,7 @@ def test_iteritems(self): ) psdf = ps.from_pandas(pdf) - for (p_name, p_items), (k_name, k_items) in zip(pdf.iteritems(), psdf.iteritems()): + for (p_name, p_items), (k_name, k_items) in zip(pdf.items(), psdf.items()): self.assert_eq(p_name, k_name) self.assert_eq(p_items, k_items) diff --git a/python/pyspark/pandas/tests/series/test_compute.py b/python/pyspark/pandas/tests/series/test_compute.py index 2fbdaef865e50..784bf29e1a25b 100644 --- a/python/pyspark/pandas/tests/series/test_compute.py +++ b/python/pyspark/pandas/tests/series/test_compute.py @@ -142,11 +142,7 @@ def test_compare(self): expected = ps.DataFrame([[1, 2], [2, 3]], index=["x", "y"], columns=["self", "other"]) self.assert_eq(expected, psser.compare(psser + 1).sort_index()) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43465): Enable SeriesTests.test_append for pandas 2.0.0.", - ) - def test_append(self): + def test_concat(self): pser1 = pd.Series([1, 2, 3], name="0") pser2 = pd.Series([4, 5, 6], name="0") pser3 = pd.Series([4, 5, 6], index=[3, 4, 5], name="0") @@ -154,17 +150,13 @@ def test_append(self): psser2 = ps.from_pandas(pser2) psser3 = ps.from_pandas(pser3) - self.assert_eq(psser1.append(psser2), pser1.append(pser2)) - self.assert_eq(psser1.append(psser3), pser1.append(pser3)) + self.assert_eq(ps.concat([psser1, psser2]), pd.concat([pser1, pser2])) + self.assert_eq(ps.concat([psser1, psser3]), pd.concat([pser1, pser3])) self.assert_eq( - psser1.append(psser2, ignore_index=True), pser1.append(pser2, ignore_index=True) + ps.concat([psser1, psser2], ignore_index=True), + pd.concat([pser1, pser2], ignore_index=True), ) - psser1.append(psser3, verify_integrity=True) - msg = "Indices have overlapping values" - with self.assertRaises(ValueError, msg=msg): - psser1.append(psser2, verify_integrity=True) - def test_shift(self): pser = pd.Series([10, 20, 15, 30, 45], name="x") psser = ps.Series(pser) @@ -415,10 +407,6 @@ def test_abs(self): self.assert_eq(abs(psser), abs(pser)) self.assert_eq(np.abs(psser), np.abs(pser)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43550): Enable SeriesTests.test_factorize for pandas 2.0.0.", - ) def test_factorize(self): pser = pd.Series(["a", "b", "a", "b"]) psser = ps.from_pandas(pser) @@ -479,7 +467,7 @@ def test_factorize(self): pcodes, puniques = pser.factorize() kcodes, kuniques = psser.factorize() self.assert_eq(pcodes, kcodes.to_list()) - # pandas: Float64Index([], dtype='float64') + # pandas: 
Index([], dtype='float64') self.assert_eq(pd.Index([]), kuniques) pser = pd.Series([np.nan, np.nan]) @@ -487,7 +475,7 @@ def test_factorize(self): pcodes, puniques = pser.factorize() kcodes, kuniques = psser.factorize() self.assert_eq(pcodes, kcodes.to_list()) - # pandas: Float64Index([], dtype='float64') + # pandas: Index([], dtype='float64') self.assert_eq(pd.Index([]), kuniques) # @@ -500,27 +488,27 @@ def test_factorize(self): pser = pd.Series(["a", "b", "a", np.nan, None]) psser = ps.from_pandas(pser) - pcodes, puniques = pser.factorize(sort=True, na_sentinel=-2) - kcodes, kuniques = psser.factorize(na_sentinel=-2) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=-2) + kcodes, kuniques = psser.factorize(use_na_sentinel=-2) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) - pcodes, puniques = pser.factorize(sort=True, na_sentinel=2) - kcodes, kuniques = psser.factorize(na_sentinel=2) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=2) + kcodes, kuniques = psser.factorize(use_na_sentinel=2) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) if not pd_below_1_1_2: - pcodes, puniques = pser.factorize(sort=True, na_sentinel=None) - kcodes, kuniques = psser.factorize(na_sentinel=None) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=None) + kcodes, kuniques = psser.factorize(use_na_sentinel=None) self.assert_eq(pcodes.tolist(), kcodes.to_list()) # puniques is Index(['a', 'b', nan], dtype='object') self.assert_eq(ps.Index(["a", "b", None]), kuniques) psser = ps.Series([1, 2, np.nan, 4, 5]) # Arrow takes np.nan as null psser.loc[3] = np.nan # Spark takes np.nan as NaN - kcodes, kuniques = psser.factorize(na_sentinel=None) - pcodes, puniques = psser._to_pandas().factorize(sort=True, na_sentinel=None) + kcodes, kuniques = psser.factorize(use_na_sentinel=None) + pcodes, puniques = psser._to_pandas().factorize(sort=True, use_na_sentinel=None) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) diff --git a/python/pyspark/pandas/tests/series/test_series.py b/python/pyspark/pandas/tests/series/test_series.py index 116acb2a5b2b3..136d905eb494b 100644 --- a/python/pyspark/pandas/tests/series/test_series.py +++ b/python/pyspark/pandas/tests/series/test_series.py @@ -670,15 +670,11 @@ def test_filter(self): with self.assertRaisesRegex(ValueError, "The item should not be empty."): psser.filter(items=[(), ("three", "z")]) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43480): Enable SeriesTests.test_iteritems for pandas 2.0.0.", - ) - def test_iteritems(self): + def test_items(self): pser = pd.Series(["A", "B", "C"]) psser = ps.from_pandas(pser) - for (p_name, p_items), (k_name, k_items) in zip(pser.iteritems(), psser.iteritems()): + for (p_name, p_items), (k_name, k_items) in zip(pser.items(), psser.items()): self.assert_eq(p_name, k_name) self.assert_eq(p_items, k_items) @@ -692,7 +688,8 @@ def test_dot(self): psdf_other = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=["x", "y", "z"]) with self.assertRaisesRegex(ValueError, "matrices are not aligned"): - psdf["b"].dot(psdf_other) + with ps.option_context("compute.ops_on_diff_frames", True): + psdf["b"].dot(psdf_other) def test_tail(self): pser = pd.Series(range(1000), name="Koalas") diff --git a/python/pyspark/pandas/tests/series/test_stat.py b/python/pyspark/pandas/tests/series/test_stat.py index 0d6e242492149..048a4c94fd939 100644 --- 
a/python/pyspark/pandas/tests/series/test_stat.py +++ b/python/pyspark/pandas/tests/series/test_stat.py @@ -524,41 +524,6 @@ def test_div_zero_and_nan(self): self.assert_eq(pser // 0, psser // 0) self.assert_eq(pser.floordiv(np.nan), psser.floordiv(np.nan)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43468): Enable SeriesTests.test_mad for pandas 2.0.0.", - ) - def test_mad(self): - pser = pd.Series([1, 2, 3, 4], name="Koalas") - psser = ps.from_pandas(pser) - - self.assert_eq(pser.mad(), psser.mad()) - - pser = pd.Series([None, -2, 5, 10, 50, np.nan, -20], name="Koalas") - psser = ps.from_pandas(pser) - - self.assert_eq(pser.mad(), psser.mad()) - - pmidx = pd.MultiIndex.from_tuples( - [("a", "1"), ("a", "2"), ("b", "1"), ("b", "2"), ("c", "1")] - ) - pser = pd.Series([1, 2, 3, 4, 5], name="Koalas") - pser.index = pmidx - psser = ps.from_pandas(pser) - - self.assert_eq(pser.mad(), psser.mad()) - - pmidx = pd.MultiIndex.from_tuples( - [("a", "1"), ("a", "2"), ("b", "1"), ("b", "2"), ("c", "1")] - ) - pser = pd.Series([None, -2, 5, 50, np.nan], name="Koalas") - pser.index = pmidx - psser = ps.from_pandas(pser) - - # Mark almost as True to avoid precision issue like: - # "21.555555555555554 != 21.555555555555557" - self.assert_eq(pser.mad(), psser.mad(), almost=True) - @unittest.skipIf( LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), "TODO(SPARK-43481): Enable SeriesTests.test_product for pandas 2.0.0.", diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py index 3d658446f2766..0bb03dd8749da 100644 --- a/python/pyspark/pandas/tests/test_utils.py +++ b/python/pyspark/pandas/tests/test_utils.py @@ -208,10 +208,10 @@ def test_series_error_assert_pandas_equal(self): exception=pe.exception, error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": series1, - "left_dtype": series1.dtype, - "right": series2, - "right_dtype": series2.dtype, + "left": series1.to_string(), + "left_dtype": str(series1.dtype), + "right": series2.to_string(), + "right_dtype": str(series2.dtype), }, ) @@ -227,9 +227,9 @@ def test_index_error_assert_pandas_equal(self): error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": index1, - "left_dtype": index1.dtype, + "left_dtype": str(index1.dtype), "right": index2, - "right_dtype": index2.dtype, + "right_dtype": str(index2.dtype), }, ) @@ -247,9 +247,9 @@ def test_multiindex_error_assert_pandas_almost_equal(self): error_class="DIFFERENT_PANDAS_MULTIINDEX", message_parameters={ "left": multiindex1, - "left_dtype": multiindex1.dtype, + "left_dtype": str(multiindex1.dtype), "right": multiindex2, - "right_dtype": multiindex2.dtype, + "right_dtype": str(multiindex1.dtype), }, ) diff --git a/python/pyspark/pandas/usage_logging/__init__.py b/python/pyspark/pandas/usage_logging/__init__.py index e14a905e78a04..4478b6c85f662 100644 --- a/python/pyspark/pandas/usage_logging/__init__.py +++ b/python/pyspark/pandas/usage_logging/__init__.py @@ -29,7 +29,6 @@ from pyspark.pandas.indexes.category import CategoricalIndex from pyspark.pandas.indexes.datetimes import DatetimeIndex from pyspark.pandas.indexes.multi import MultiIndex -from pyspark.pandas.indexes.numeric import Float64Index, Int64Index from pyspark.pandas.missing.frame import MissingPandasLikeDataFrame from pyspark.pandas.missing.general_functions import MissingPandasLikeGeneralFunctions from pyspark.pandas.missing.groupby import ( @@ -89,8 +88,6 @@ def attach(logger_module: Union[str, ModuleType]) -> None: 
Series, Index, MultiIndex, - Int64Index, - Float64Index, CategoricalIndex, DatetimeIndex, DataFrameGroupBy, diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index c66b3359e77d1..55b9a57ef6187 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -478,12 +478,7 @@ def is_testing() -> bool: def default_session() -> SparkSession: - if not is_remote(): - spark = SparkSession.getActiveSession() - else: - from pyspark.sql.connect.session import _active_spark_session - - spark = _active_spark_session # type: ignore[assignment] + spark = SparkSession.getActiveSession() if spark is None: spark = SparkSession.builder.appName("pandas-on-Spark").getOrCreate() diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index a82c596555f8a..a7c3a92d3b1dc 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -65,7 +65,10 @@ from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation from pyspark.sql.connect.client.artifact import ArtifactManager -from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator +from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, + RetryException, +) from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib @@ -1549,7 +1552,7 @@ def __exit__( ) -> Optional[bool]: if isinstance(exc_val, BaseException): # Swallow the exception. - if self._can_retry(exc_val): + if self._can_retry(exc_val) or isinstance(exc_val, RetryException): self._retry_state.set_exception(exc_val) return True # Bubble up the exception. diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 4d4cce0ca4413..70c7d126ff105 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -21,10 +21,13 @@ import warnings import uuid from collections.abc import Generator -from typing import Optional, Dict, Any, Iterator, Iterable, Tuple +from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast from multiprocessing.pool import ThreadPool import os +import grpc +from grpc_status import rpc_status + import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib @@ -42,15 +45,12 @@ class ExecutePlanResponseReattachableIterator(Generator): Initial iterator is the result of an ExecutePlan on the request, but it can be reattached with ReattachExecute request. ReattachExecute request is provided the responseId of last returned ExecutePlanResponse on the iterator to return a new iterator from server that continues after - that. + that. If the initial ExecutePlan did not even reach the server, and hence reattach fails with + INVALID_HANDLE.OPERATION_NOT_FOUND, we attempt to retry ExecutePlan. In reattachable execute the server does buffer some responses in case the client needs to backtrack. To let server release this buffer sooner, this iterator asynchronously sends ReleaseExecute RPCs that instruct the server to release responses that it already processed. - - Note: If the initial ExecutePlan did not even reach the server and execution didn't start, - the ReattachExecute can still fail with INVALID_HANDLE.OPERATION_NOT_FOUND, failing the whole - operation. 
""" _release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8) @@ -93,6 +93,7 @@ def __init__( # Initial iterator comes from ExecutePlan request. # Note: This is not retried, because no error would ever be thrown here, and GRPC will only # throw error on first self._has_next(). + self._metadata = metadata self._iterator: Iterator[pb2.ExecutePlanResponse] = iter( self._stub.ExecutePlan(self._initial_request, metadata=metadata) ) @@ -111,9 +112,9 @@ def send(self, value: Any) -> pb2.ExecutePlanResponse: self._last_returned_response_id = ret.response_id if ret.HasField("result_complete"): self._result_complete = True - self._release_execute(None) # release all + self._release_all() else: - self._release_execute(self._last_returned_response_id) + self._release_until(self._last_returned_response_id) self._current = None return ret @@ -125,61 +126,93 @@ def _has_next(self) -> bool: # After response complete response return False else: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): - with attempt: - # on first try, we use the existing iterator. - if not attempt.is_first_try(): - # on retry, the iterator is borked, so we need a new one - self._iterator = iter( - self._stub.ReattachExecute(self._create_reattach_execute_request()) - ) - - if self._current is None: - try: - self._current = next(self._iterator) - except StopIteration: - pass - - has_next = self._current is not None - - # Graceful reattach: - # If iterator ended, but there was no ResponseComplete, it means that - # there is more, and we need to reattach. While ResponseComplete didn't - # arrive, we keep reattaching. - if not self._result_complete and not has_next: - while not has_next: + try: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + # on first try, we use the existing iterator. + if not attempt.is_first_try(): + # on retry, the iterator is borked, so we need a new one self._iterator = iter( self._stub.ReattachExecute(self._create_reattach_execute_request()) ) - # shouldn't change - assert not self._result_complete + + if self._current is None: try: - self._current = next(self._iterator) + self._current = self._call_iter(lambda: next(self._iterator)) except StopIteration: pass - has_next = self._current is not None - return has_next + + has_next = self._current is not None + + # Graceful reattach: + # If iterator ended, but there was no ResponseComplete, it means that + # there is more, and we need to reattach. While ResponseComplete didn't + # arrive, we keep reattaching. + if not self._result_complete and not has_next: + while not has_next: + self._iterator = iter( + self._stub.ReattachExecute( + self._create_reattach_execute_request() + ) + ) + # shouldn't change + assert not self._result_complete + try: + self._current = self._call_iter(lambda: next(self._iterator)) + except StopIteration: + pass + has_next = self._current is not None + return has_next + except Exception as e: + self._release_all() + raise e return False - def _release_execute(self, until_response_id: Optional[str]) -> None: + def _release_until(self, until_response_id: str) -> None: """ - Inform the server to release the execution. + Inform the server to release the buffered execution results until and including given + result. This will send an asynchronous RPC which will not block this iterator, the iterator can continue to be consumed. 
+ """ + if self._result_complete: + return - Release with untilResponseId informs the server that the iterator has been consumed until - and including response with that responseId, and these responses can be freed. + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.client.core import Retrying + + request = self._create_release_execute_request(until_response_id) - Release with None means that the responses have been completely consumed and informs the - server that the completed execution can be completely freed. + def target() -> None: + try: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + self._stub.ReleaseExecute(request) + except Exception as e: + warnings.warn(f"ReleaseExecute failed with exception: {e}.") + + ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + + def _release_all(self) -> None: """ + Inform the server to release the execution, either because all results were consumed, + or the execution finished with error and the error was received. + + This will send an asynchronous RPC which will not block this. The client continues + executing, and if the release fails, server is equipped to deal with abandoned executions. + """ + if self._result_complete: + return + from pyspark.sql.connect.client.core import SparkConnectClient from pyspark.sql.connect.client.core import Retrying - request = self._create_release_execute_request(until_response_id) + request = self._create_release_execute_request(None) def target() -> None: try: @@ -192,6 +225,34 @@ def target() -> None: warnings.warn(f"ReleaseExecute failed with exception: {e}.") ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + self._result_complete = True + + def _call_iter(self, iter_fun: Callable) -> Any: + """ + Call next() on the iterator. If this fails with this operationId not existing + on the server, this means that the initial ExecutePlan request didn't even reach the + server. In that case, attempt to start again with ExecutePlan. + + Called inside retry block, so retryable failure will get handled upstream. + """ + try: + return iter_fun() + except grpc.RpcError as e: + status = rpc_status.from_call(cast(grpc.Call, e)) + if "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message: + if self._last_returned_response_id is not None: + raise RuntimeError( + "OPERATION_NOT_FOUND on the server but " + "responses were already received from it.", + e, + ) + # Try a new ExecutePlan, and throw upstream for retry. + self._iterator = iter( + self._stub.ExecutePlan(self._initial_request, metadata=self._metadata) + ) + raise RetryException() + else: + raise e def _create_reattach_execute_request(self) -> pb2.ReattachExecuteRequest: reattach = pb2.ReattachExecuteRequest( @@ -231,7 +292,15 @@ def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> A super().throw(type, value, traceback) def close(self) -> None: + self._release_all() return super().close() def __del__(self) -> None: return self.close() + + +class RetryException(Exception): + """ + An exception that can be thrown upstream when inside retry and which will be retryable + regardless of policy. 
+ """ diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 58dffd93bf9b5..7da93ef413c20 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -21,6 +21,7 @@ from typing import Any, List, Optional, Type, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict import functools import json +import pickle from threading import Lock from inspect import signature, isclass @@ -40,7 +41,7 @@ LiteralExpression, ) from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType -from pyspark.errors import PySparkTypeError, PySparkNotImplementedError +from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName @@ -2202,7 +2203,17 @@ def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF: if self._return_type is not None: udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type)) udtf.eval_type = self._eval_type - udtf.command = CloudPickleSerializer().dumps(self._func) + try: + udtf.command = CloudPickleSerializer().dumps(self._func) + except pickle.PicklingError: + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "Please check the stack trace and " + "make sure the function is serializable.", + }, + ) udtf.python_ver = self._python_ver return udtf diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index e870221594c13..a886ecbd61842 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -2554,9 +2554,7 @@ class ReleaseExecuteRequest(google.protobuf.message.Message): class ReleaseAll(google.protobuf.message.Message): """Release and close operation completely. - Note: This should be called when the server side operation is finished, and ExecutePlan or - ReattachExecute are finished processing the result stream, or inside onComplete / onError. - This will not interrupt a running execution, but block until it's finished. + This will also interrupt the query if it is running execution, and wait for it to be torn down. """ DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 9bba0db05e43f..d75a30c561f93 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -18,6 +18,7 @@ check_dependencies(__name__) +import threading import os import warnings from collections.abc import Sized @@ -36,6 +37,7 @@ overload, Iterable, TYPE_CHECKING, + ClassVar, ) import numpy as np @@ -93,14 +95,13 @@ from pyspark.sql.connect.udtf import UDTFRegistration -# `_active_spark_session` stores the active spark connect session created by -# `SparkSession.builder.getOrCreate`. It is used by ML code. 
-# If sessions are created with `SparkSession.builder.create`, it stores -# The last created session -_active_spark_session = None - - class SparkSession: + # The active SparkSession for the current thread + _active_session: ClassVar[threading.local] = threading.local() + # Reference to the root SparkSession + _default_session: ClassVar[Optional["SparkSession"]] = None + _lock: ClassVar[RLock] = RLock() + class Builder: """Builder for :class:`SparkSession`.""" @@ -176,8 +177,6 @@ def enableHiveSupport(self) -> "SparkSession.Builder": ) def create(self) -> "SparkSession": - global _active_spark_session - has_channel_builder = self._channel_builder is not None has_spark_remote = "spark.remote" in self._options @@ -200,23 +199,26 @@ def create(self) -> "SparkSession": assert spark_remote is not None session = SparkSession(connection=spark_remote) - _active_spark_session = session + SparkSession._set_default_and_active_session(session) return session def getOrCreate(self) -> "SparkSession": - global _active_spark_session - if _active_spark_session is not None: - return _active_spark_session - _active_spark_session = self.create() - return _active_spark_session + with SparkSession._lock: + session = SparkSession.getActiveSession() + if session is None: + session = SparkSession._default_session + if session is None: + session = self.create() + return session _client: SparkConnectClient @classproperty def builder(cls) -> Builder: - """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" return cls.Builder() + builder.__doc__ = PySparkSession.builder.__doc__ + def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] = None): """ Creates a new SparkSession for the Spark Connect interface. @@ -236,6 +238,38 @@ def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] self._client = SparkConnectClient(connection=connection, user_id=userId) self._session_id = self._client._session_id + @classmethod + def _set_default_and_active_session(cls, session: "SparkSession") -> None: + """ + Set the (global) default :class:`SparkSession`, and (thread-local) + active :class:`SparkSession` when they are not set yet. + """ + with cls._lock: + if cls._default_session is None: + cls._default_session = session + if getattr(cls._active_session, "session", None) is None: + cls._active_session.session = session + + @classmethod + def getActiveSession(cls) -> Optional["SparkSession"]: + return getattr(cls._active_session, "session", None) + + getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__ + + @classmethod + def active(cls) -> "SparkSession": + session = cls.getActiveSession() + if session is None: + session = cls._default_session + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_OR_DEFAULT_SESSION", + message_parameters={}, + ) + return session + + active.__doc__ = PySparkSession.active.__doc__ + def table(self, tableName: str) -> DataFrame: return self.read.table(tableName) @@ -251,6 +285,8 @@ def read(self) -> "DataFrameReader": def readStream(self) -> "DataStreamReader": return DataStreamReader(self) + readStream.__doc__ = PySparkSession.readStream.__doc__ + def _inferSchemaFromList( self, data: Iterable[Any], names: Optional[List[str]] = None ) -> StructType: @@ -601,19 +637,20 @@ def stop(self) -> None: # specifically in Spark Connect the Spark Connect server is designed for # multi-tenancy - the remote client side cannot just stop the server and stop # other remote clients being used from other users. 
- global _active_spark_session - self.client.close() - _active_spark_session = None - - if "SPARK_LOCAL_REMOTE" in os.environ: - # When local mode is in use, follow the regular Spark session's - # behavior by terminating the Spark Connect server, - # meaning that you can stop local mode, and restart the Spark Connect - # client with a different remote address. - active_session = PySparkSession.getActiveSession() - if active_session is not None: - active_session.stop() - with SparkContext._lock: + with SparkSession._lock: + self.client.close() + if self is SparkSession._default_session: + SparkSession._default_session = None + if self is getattr(SparkSession._active_session, "session", None): + SparkSession._active_session.session = None + + if "SPARK_LOCAL_REMOTE" in os.environ: + # When local mode is in use, follow the regular Spark session's + # behavior by terminating the Spark Connect server, + # meaning that you can stop local mode, and restart the Spark Connect + # client with a different remote address. + if PySparkSession._activeSession is not None: + PySparkSession._activeSession.stop() del os.environ["SPARK_LOCAL_REMOTE"] del os.environ["SPARK_CONNECT_MODE_ENABLED"] if "SPARK_REMOTE" in os.environ: @@ -628,20 +665,18 @@ def is_stopped(self) -> bool: """ return self.client.is_closed - @classmethod - def getActiveSession(cls) -> Any: - raise PySparkNotImplementedError( - error_class="NOT_IMPLEMENTED", message_parameters={"feature": "getActiveSession()"} - ) - @property def conf(self) -> RuntimeConf: return RuntimeConf(self.client) + conf.__doc__ = PySparkSession.conf.__doc__ + @property def streams(self) -> "StreamingQueryManager": return StreamingQueryManager(self) + streams.__doc__ = PySparkSession.streams.__doc__ + def __getattr__(self, name: str) -> Any: if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession"]: raise PySparkAttributeError( @@ -675,6 +710,8 @@ def version(self) -> str: assert result is not None return result + version.__doc__ = PySparkSession.version.__doc__ + @property def client(self) -> "SparkConnectClient": return self._client diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py index 054788539f293..48a9848de4009 100644 --- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py @@ -76,7 +76,9 @@ def process(df_id, batch_id): # type: ignore[no-untyped-def] # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) + # There could be a long time between each micro batch. + sock.settimeout(None) write_int(os.getpid(), sock_file) sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py index 8eb310461b6f6..7aef911426de7 100644 --- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -89,7 +89,9 @@ def process(listener_event_str, listener_event_type): # type: ignore[no-untyped # Read information about how to connect back to the JVM from the environment. 
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) + # There could be a long time between each listener event. + sock.settimeout(None) write_int(os.getpid(), sock_file) sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 2d7e423d3d571..eb0541b936925 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -37,8 +37,7 @@ from pyspark.sql.connect.types import UnparsedDataType from pyspark.sql.types import DataType, StringType from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration -from pyspark.errors import PySparkTypeError - +from pyspark.errors import PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql.connect._typing import ( @@ -58,14 +57,20 @@ def _create_py_udf( from pyspark.sql.udf import _create_arrow_py_udf if useArrow is None: - from pyspark.sql.connect.session import _active_spark_session - - is_arrow_enabled = ( - False - if _active_spark_session is None - else _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") - == "true" - ) + is_arrow_enabled = False + try: + from pyspark.sql.connect.session import SparkSession + + session = SparkSession.active() + is_arrow_enabled = ( + str(session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled")).lower() + == "true" + ) + except PySparkRuntimeError as e: + if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION": + pass # Just uses the default if no session found. + else: + raise e else: is_arrow_enabled = useArrow diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 919994401c802..c8495626292c5 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -68,13 +68,20 @@ def _create_py_udtf( if useArrow is not None: arrow_enabled = useArrow else: - from pyspark.sql.connect.session import _active_spark_session - - arrow_enabled = ( - _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true" - if _active_spark_session is not None - else True - ) + from pyspark.sql.connect.session import SparkSession + + arrow_enabled = False + try: + session = SparkSession.active() + arrow_enabled = ( + str(session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")).lower() + == "true" + ) + except PySparkRuntimeError as e: + if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION": + pass # Just uses the default if no session found. + else: + raise e # Create a regular Python UDTF and check for invalid handler class. regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) @@ -160,17 +167,13 @@ def _build_common_inline_user_defined_table_function( ) def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + from pyspark.sql.connect.session import SparkSession from pyspark.sql.connect.dataframe import DataFrame - from pyspark.sql.connect.session import _active_spark_session - if _active_spark_session is None: - raise PySparkRuntimeError( - "An active SparkSession is required for " - "executing a Python user-defined table function." 
- ) + session = SparkSession.active() plan = self._build_common_inline_user_defined_table_function(*cols) - return DataFrame.withPlan(plan, _active_spark_session) + return DataFrame.withPlan(plan, session) def asNondeterministic(self) -> "UserDefinedTableFunction": self.deterministic = False diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index f3037c8b39c86..d1a3babb1fdc0 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -571,7 +571,10 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) # TODO(SPARK-43579): cache the converter for reuse conv = _create_converter_from_pandas( - dt, timezone=self._timezone, error_on_duplicated_field_names=False + dt, + timezone=self._timezone, + error_on_duplicated_field_names=False, + ignore_unexpected_complex_type_values=True, ) series = conv(series) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 533620476041a..b02a003e632cb 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -21,7 +21,7 @@ """ import datetime import itertools -from typing import Any, Callable, List, Optional, Union, TYPE_CHECKING +from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING from pyspark.sql.types import ( cast, @@ -750,6 +750,7 @@ def _create_converter_from_pandas( *, timezone: Optional[str], error_on_duplicated_field_names: bool = True, + ignore_unexpected_complex_type_values: bool = False, ) -> Callable[["pd.Series"], "pd.Series"]: """ Create a converter of pandas Series to create Spark DataFrame with Arrow optimization. @@ -763,6 +764,17 @@ def _create_converter_from_pandas( error_on_duplicated_field_names : bool, optional Whether raise an exception when there are duplicated field names. (default ``True``) + ignore_unexpected_complex_type_values : bool, optional + Whether ignore the case where unexpected values are given for complex types. + If ``False``, each complex type expects: + + * array type: :class:`Iterable` + * map type: :class:`dict` + * struct type: :class:`dict` or :class:`tuple` + + and raise an AssertionError when the given value is not the expected type. + If ``True``, just ignore and return the give value. 
+ (default ``False``) Returns ------- @@ -781,15 +793,26 @@ def correct_timestamp(pser: pd.Series) -> pd.Series: def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]: if isinstance(dt, ArrayType): - _element_conv = _converter(dt.elementType) - if _element_conv is None: - return None + _element_conv = _converter(dt.elementType) or (lambda x: x) - def convert_array(value: Any) -> Any: - if value is None: - return None - else: - return [_element_conv(v) for v in value] # type: ignore[misc] + if ignore_unexpected_complex_type_values: + + def convert_array(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, Iterable): + return [_element_conv(v) for v in value] + else: + return value + + else: + + def convert_array(value: Any) -> Any: + if value is None: + return None + else: + assert isinstance(value, Iterable) + return [_element_conv(v) for v in value] return convert_array @@ -797,12 +820,24 @@ def convert_array(value: Any) -> Any: _key_conv = _converter(dt.keyType) or (lambda x: x) _value_conv = _converter(dt.valueType) or (lambda x: x) - def convert_map(value: Any) -> Any: - if value is None: - return None - else: - assert isinstance(value, dict) - return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] + if ignore_unexpected_complex_type_values: + + def convert_map(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] + else: + return value + + else: + + def convert_map(value: Any) -> Any: + if value is None: + return None + else: + assert isinstance(value, dict) + return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] return convert_map @@ -820,17 +855,38 @@ def convert_map(value: Any) -> Any: field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields] - def convert_struct(value: Any) -> Any: - if value is None: - return None - elif isinstance(value, dict): - return { - dedup_field_names[i]: field_convs[i](value.get(key, None)) - for i, key in enumerate(field_names) - } - else: - assert isinstance(value, tuple) - return {dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)} + if ignore_unexpected_complex_type_values: + + def convert_struct(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return { + dedup_field_names[i]: field_convs[i](value.get(key, None)) + for i, key in enumerate(field_names) + } + elif isinstance(value, tuple): + return { + dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value) + } + else: + return value + + else: + + def convert_struct(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return { + dedup_field_names[i]: field_convs[i](value.get(key, None)) + for i, key in enumerate(field_names) + } + else: + assert isinstance(value, tuple) + return { + dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value) + } return convert_struct diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index ede6318782e0a..9141051fdf830 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -64,8 +64,8 @@ _from_numpy_type, ) from pyspark.errors.exceptions.captured import install_exception_handler -from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str -from pyspark.errors import PySparkValueError, PySparkTypeError +from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod +from pyspark.errors import 
PySparkValueError, PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql._typing import AtomicValue, RowLike, OptionalPrimitiveType @@ -500,7 +500,7 @@ def getOrCreate(self) -> "SparkSession": ).applyModifiableSettings(session._jsparkSession, self._options) return session - # SparkConnect-specific API + # Spark Connect-specific API def create(self) -> "SparkSession": """Creates a new SparkSession. Can only be used in the context of Spark Connect and will throw an exception otherwise. @@ -510,6 +510,10 @@ def create(self) -> "SparkSession": Returns ------- :class:`SparkSession` + + Notes + ----- + This method will update the default and/or active session if they are not set. """ opts = dict(self._options) if "SPARK_REMOTE" in os.environ or "spark.remote" in opts: @@ -546,7 +550,11 @@ def create(self) -> "SparkSession": # to Python 3.9.6 (https://github.com/python/cpython/pull/28838) @classproperty def builder(cls) -> Builder: - """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" + """Creates a :class:`Builder` for constructing a :class:`SparkSession`. + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + """ return cls.Builder() _instantiatedSession: ClassVar[Optional["SparkSession"]] = None @@ -632,12 +640,16 @@ def newSession(self) -> "SparkSession": return self.__class__(self._sc, self._jsparkSession.newSession()) @classmethod + @try_remote_session_classmethod def getActiveSession(cls) -> Optional["SparkSession"]: """ Returns the active :class:`SparkSession` for the current thread, returned by the builder .. versionadded:: 3.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`SparkSession` @@ -667,6 +679,30 @@ def getActiveSession(cls) -> Optional["SparkSession"]: else: return None + @classmethod + @try_remote_session_classmethod + def active(cls) -> "SparkSession": + """ + Returns the active or default :class:`SparkSession` for the current thread, returned by + the builder. + + .. versionadded:: 3.5.0 + + Returns + ------- + :class:`SparkSession` + Spark session if an active or default session exists for the current thread. + """ + session = cls.getActiveSession() + if session is None: + session = cls._instantiatedSession + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_OR_DEFAULT_SESSION", + message_parameters={}, + ) + return session + @property def sparkContext(self) -> SparkContext: """ @@ -698,6 +734,9 @@ def version(self) -> str: .. versionadded:: 2.0.0 + .. versionchanged:: 3.4.0 + Supports Spark Connect. + Returns ------- str @@ -719,6 +758,9 @@ def conf(self) -> RuntimeConfig: .. versionadded:: 2.0.0 + .. versionchanged:: 3.4.0 + Supports Spark Connect. + Returns ------- :class:`pyspark.sql.conf.RuntimeConfig` @@ -726,7 +768,7 @@ def conf(self) -> RuntimeConfig: Examples -------- >>> spark.conf - + Set a runtime configuration for the session @@ -805,6 +847,9 @@ def udtf(self) -> "UDTFRegistration": .. versionadded:: 3.5.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`UDTFRegistration` @@ -1639,6 +1684,9 @@ def readStream(self) -> DataStreamReader: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -1650,7 +1698,7 @@ def readStream(self) -> DataStreamReader: Examples -------- >>> spark.readStream - + The example below uses Rate source that generates rows continuously. After that, we operate a modulo by 3, and then write the stream out to the console. 
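
A minimal usage sketch of the session resolution that the new SparkSession.active() classmethod above performs (thread-local active session first, then the default/instantiated session, otherwise a NO_ACTIVE_OR_DEFAULT_SESSION error). This is not part of the patch; it only illustrates the documented behavior, and the appName value is made up:

    from pyspark.sql import SparkSession
    from pyspark.errors import PySparkRuntimeError

    spark = SparkSession.builder.appName("active-demo").getOrCreate()
    # While a session is running, active() resolves it (it raises instead of returning None).
    assert SparkSession.active() is not None

    spark.stop()  # clears both the active and the default session
    try:
        SparkSession.active()
    except PySparkRuntimeError as e:
        # nothing left to fall back on
        assert e.getErrorClass() == "NO_ACTIVE_OR_DEFAULT_SESSION"

The Spark Connect SparkSession in python/pyspark/sql/connect/session.py implements the same active-then-default order, so the sketch applies to both code paths.
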
@@ -1672,6 +1720,9 @@ def streams(self) -> "StreamingQueryManager": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -1683,7 +1734,7 @@ def streams(self) -> "StreamingQueryManager": Examples -------- >>> spark.streams - + Get the list of active streaming queries diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 547462d4da6d5..4bf58bf7807b3 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -60,6 +60,10 @@ def test_listener_events(self): try: self.spark.streams.addListener(test_listener) + # This ensures the read socket on the server won't crash (i.e. because of timeout) + # when there hasn't been a new event for a long time + time.sleep(30) + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() q = df.writeStream.format("noop").queryName("test").start() @@ -76,6 +80,9 @@ def test_listener_events(self): finally: self.spark.streams.removeListener(test_listener) + # Remove again to verify this won't throw any error + self.spark.streams.removeListener(test_listener) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 065f1585a9f06..0687fc9f31331 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3043,9 +3043,6 @@ def test_unsupported_functions(self): def test_unsupported_session_functions(self): # SPARK-41934: Disable unsupported functions. - with self.assertRaises(NotImplementedError): - RemoteSparkSession.getActiveSession() - with self.assertRaises(NotImplementedError): RemoteSparkSession.builder.enableHiveSupport() @@ -3331,6 +3328,7 @@ def test_error_stack_trace(self): spark.stop() def test_can_create_multiple_sessions_to_different_remotes(self): + self.spark.stop() self.assertIsNotNone(self.spark._client) # Creates a new remote session. 
other = PySparkSession.builder.remote("sc://other.remote:114/").create() diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 748b611e66707..e12e697e582da 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -45,6 +45,9 @@ def tearDownClass(cls): # TODO: use PySpark error classes instead of SparkConnectGrpcException + def test_struct_output_type_casting_row(self): + self.check_struct_output_type_casting_row(SparkConnectGrpcException) + def test_udtf_with_invalid_return_type(self): @udtf(returnType="int") class TestUDTF: diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 083aa151d0dd8..7cb13693a0df9 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -1323,6 +1323,15 @@ def test_row_without_column_name(self): # test __repr__ with unicode values self.assertEqual(repr(Row("数", "量")), "") + # SPARK-44643: test __repr__ with empty Row + def test_row_repr_with_empty_row(self): + self.assertEqual(repr(Row(a=Row())), "Row(a=)") + self.assertEqual(repr(Row(Row())), ")>") + + EmptyRow = Row() + self.assertEqual(repr(Row(a=EmptyRow())), "Row(a=Row())") + self.assertEqual(repr(Row(EmptyRow())), "") + def test_empty_row(self): row = Row() self.assertEqual(len(row), 0) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 65184549573dc..300067716e9de 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -18,14 +18,16 @@ import shutil import tempfile import unittest - from typing import Iterator +from py4j.protocol import Py4JJavaError + from pyspark.errors import ( PySparkAttributeError, PythonException, PySparkTypeError, AnalysisException, + PySparkRuntimeError, ) from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType @@ -161,24 +163,30 @@ def eval(self, a: int, b: int): self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)]) def test_udtf_eval_returning_non_tuple(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a - func = udtf(TestUDTF, returnType="a: int") - # TODO(SPARK-44005): improve this error message - with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): - func(lit(1)).collect() + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() - def test_udtf_eval_returning_non_generator(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): return (a,) - func = udtf(TestUDTF, returnType="a: int") - # TODO(SPARK-44005): improve this error message - with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): - func(lit(1)).collect() + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() + + def test_udtf_with_invalid_return_value(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self, a): + return a + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF(lit(1)).collect() def test_udtf_eval_with_no_return(self): @udtf(returnType="a: int") @@ -375,6 +383,35 @@ def terminate(self): ], ) + def test_init_with_exception(self): + @udtf(returnType="x: int") + class TestUDTF: + def __init__(self): + raise Exception("error") + + def eval(self): + yield 1, + + with self.assertRaisesRegex( + PythonException, + 
r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the '__init__' method: error", + ): + TestUDTF().show() + + def test_eval_with_exception(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + raise Exception("error") + + with self.assertRaisesRegex( + PythonException, + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the 'eval' method: error", + ): + TestUDTF().show() + def test_terminate_with_exceptions(self): @udtf(returnType="a: int, b: int") class TestUDTF: @@ -386,8 +423,8 @@ def terminate(self): with self.assertRaisesRegex( PythonException, - "User defined table function encountered an error in the 'terminate' " - "method: terminate error", + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the 'terminate' method: terminate error", ): TestUDTF(lit(1)).collect() @@ -543,12 +580,14 @@ def eval(self): assertDataFrameEqual(TestUDTF(), [Row()]) - def _check_result_or_exception(self, func_handler, ret_type, expected): + def _check_result_or_exception( + self, func_handler, ret_type, expected, *, err_type=PythonException + ): func = udtf(func_handler, returnType=ret_type) if not isinstance(expected, str): assertDataFrameEqual(func(), expected) else: - with self.assertRaisesRegex(PythonException, expected): + with self.assertRaisesRegex(err_type, expected): func().collect() def test_numeric_output_type_casting(self): @@ -640,20 +679,129 @@ def eval(self): def test_array_output_type_casting(self): class TestUDTF: def eval(self): - yield [1, 2], + yield [0, 1.1, 2], for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), ("x: int", [Row(x=None)]), - ("x: array", [Row(x=[1, 2])]), - ("x: array", [Row(x=[None, None])]), - ("x: array", [Row(x=["1", "2"])]), - ("x: array", [Row(x=[None, None])]), - ("x: array>", [Row(x=[None, None])]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="[0, 1.1, 2]")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array", [Row(x=[0, None, 2])]), + ("x: array", [Row(x=[None, 1.1, None])]), + ("x: array", [Row(x=["0", "1.1", "2"])]), + ("x: array", [Row(x=[None, None, None])]), + ("x: array>", [Row(x=[None, None, None])]), ("x: map", [Row(x=None)]), + ("x: struct", [Row(x=Row(a=0, b=None, c=2))]), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) + def test_map_output_type_casting(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), + ("x: int", [Row(x=None)]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="{a=0, b=1.1, c=2}")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array", [Row(x=None)]), + ("x: map", [Row(x={"a": "0", "b": "1.1", "c": "2"})]), + ("x: map", [Row(x={"a": None, "b": None, "c": None})]), + ("x: map", [Row(x={"a": 0, "b": None, "c": 2})]), + ("x: map", [Row(x={"a": None, "b": 1.1, "c": None})]), + ("x: map>", [Row(x={"a": None, "b": None, "c": None})]), 
+ ("x: struct", [Row(x=Row(a=0))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_dict(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), + ("x: int", [Row(x=None)]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="{a=0, b=1.1, c=2}")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array", [Row(x=None)]), + ("x: map", [Row(x={"a": "0", "b": "1.1", "c": "2"})]), + ("x: struct", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct", [Row(Row(a=0, b=None, c=2))]), + ("x: struct", [Row(Row(a=None, b=1.1, c=None))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_row(self): + self.check_struct_output_type_casting_row(Py4JJavaError) + + def check_struct_output_type_casting_row(self, error_type): + class TestUDTF: + def eval(self): + yield Row(a=0, b=1.1, c=2), + + err = ("PickleException", error_type) + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", "ValueError"), + ("x: timestamp", "ValueError"), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array", err), + ("x: map", err), + ("x: struct", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct", [Row(Row(a=0, b=None, c=2))]), + ("x: struct", [Row(Row(a=None, b=1.1, c=None))]), + ]: + with self.subTest(ret_type=ret_type): + if isinstance(expected, tuple): + self._check_result_or_exception( + TestUDTF, ret_type, expected[0], err_type=expected[1] + ) + else: + self._check_result_or_exception(TestUDTF, ret_type, expected) + def test_inconsistent_output_types(self): class TestUDTF: def eval(self): @@ -702,6 +850,32 @@ def upper(s: str): }, ) + def test_udtf_pickle_error(self): + with tempfile.TemporaryDirectory() as d: + file = os.path.join(d, "file.txt") + file_obj = open(file, "w") + + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + file_obj + yield 1, + + with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"): + TestUDTF().collect() + + def test_udtf_access_spark_session(self): + df = self.spark.range(10) + + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + df.collect() + yield 1, + + with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"): + TestUDTF().collect() + def test_udtf_no_eval(self): with self.assertRaises(PySparkAttributeError) as e: @@ -1658,22 +1832,37 @@ def eval(self, x: str): PythonEvalType.SQL_ARROW_TABLE_UDF, ) + def test_udtf_arrow_sql_conf(self): + class TestUDTF: + def eval(self): + yield 1, + + # We do not use `self.sql_conf` here to test the SQL SET command + # instead of using PySpark's `spark.conf.set`. 
+ old_value = self.spark.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") + self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=False") + self.assertEqual(udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_TABLE_UDF) + self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=True") + self.assertEqual( + udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_ARROW_TABLE_UDF + ) + self.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", old_value) + def test_udtf_eval_returning_non_tuple(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a - func = udtf(TestUDTF, returnType="a: int") # When arrow is enabled, it can handle non-tuple return value. - self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) - def test_udtf_eval_returning_non_generator(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): - return (a,) + return [a] - func = udtf(TestUDTF, returnType="a: int") - self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) def test_numeric_output_type_casting(self): class TestUDTF: @@ -1696,9 +1885,8 @@ def eval(self): ("x: double", [Row(x=1.0)]), ("x: decimal(10, 0)", err), ("x: array", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map", None), - # ("x: struct", None) + ("x: map", err), + ("x: struct", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1725,10 +1913,9 @@ def eval(self): ("x: double", [Row(x=1.0)]), ("x: decimal(10, 0)", [Row(x=1)]), ("x: array", [Row(x=["1"])]), - ("x: array", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map", None), - # ("x: struct", None) + ("x: array", [Row(x=[1])]), + ("x: map", err), + ("x: struct", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1756,9 +1943,8 @@ def eval(self): ("x: decimal(10, 0)", err), ("x: array", [Row(x=["h", "e", "l", "l", "o"])]), ("x: array", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map", None), - # ("x: struct", None) + ("x: map", err), + ("x: struct", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1789,9 +1975,103 @@ def eval(self): ("x: array", [Row(x=[0, 1, 2])]), ("x: array", [Row(x=[0, 1.1, 2])]), ("x: array>", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map", None), - # ("x: struct", None) + ("x: map", err), + ("x: struct", err), + ("x: struct", err), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_map_output_type_casting(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array", [Row(x=["a", "b", "c"])]), + ("x: map", err), + ("x: map", err), + ("x: map", [Row(x={"a": 0, "b": 1, "c": 2})]), + ("x: map", [Row(x={"a": 0, "b": 1.1, "c": 2})]), + ("x: map>", err), + 
("x: struct", [Row(x=Row(a=0))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_dict(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array", [Row(x=["a", "b", "c"])]), + ("x: map", err), + ("x: struct", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct", [Row(Row(a=0, b=1, c=2))]), + ("x: struct", [Row(Row(a=0, b=1.1, c=2))]), + ("x: struct,b:struct<>,c:struct<>>", err), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_row(self): + class TestUDTF: + def eval(self): + yield Row(a=0, b=1.1, c=2), + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array", [Row(x=["0", "1.1", "2"])]), + ("x: map", err), + ("x: struct", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct", [Row(Row(a=0, b=1, c=2))]), + ("x: struct", [Row(Row(a=0, b=1.1, c=2))]), + ("x: struct,b:struct<>,c:struct<>>", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index 76d397e3adeb8..93895465de7f7 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -41,6 +41,7 @@ BooleanType, ) from pyspark.sql.dataframe import DataFrame +import pyspark.pandas as ps import difflib from typing import List, Union @@ -672,9 +673,79 @@ def test_assert_equal_nulldf(self): assertDataFrameEqual(df1, df2, checkRowOrder=False) assertDataFrameEqual(df1, df2, checkRowOrder=True) - def test_assert_equal_exact_pandas_df(self): - import pyspark.pandas as ps + def test_assert_unequal_null_actual(self): + df1 = None + df2 = self.spark.createDataFrame( + data=[ + ("1", 1000), + ("2", 3000), + ], + schema=["id", "amount"], + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2) + + self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "actual", + "actual_type": None, + }, + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "actual", + "actual_type": None, + }, + ) + + def test_assert_unequal_null_expected(self): + df1 = self.spark.createDataFrame( + data=[ + ("1", 1000), + ("2", 3000), + ], + schema=["id", "amount"], + ) + df2 = None + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2) + 
self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "expected", + "actual_type": None, + }, + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "expected", + "actual_type": None, + }, + ) + + def test_assert_equal_exact_pandas_df(self): df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) df2 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) @@ -682,16 +753,12 @@ def test_assert_equal_exact_pandas_df(self): assertDataFrameEqual(df1, df2, checkRowOrder=True) def test_assert_equal_exact_pandas_df(self): - import pyspark.pandas as ps - df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) df2 = ps.DataFrame(data=[30, 20, 10], columns=["Numbers"]) assertDataFrameEqual(df1, df2) def test_assert_equal_approx_pandas_df(self): - import pyspark.pandas as ps - df1 = ps.DataFrame(data=[10.0001, 20.32, 30.1], columns=["Numbers"]) df2 = ps.DataFrame(data=[10.0, 20.32, 30.1], columns=["Numbers"]) @@ -699,7 +766,6 @@ def test_assert_equal_approx_pandas_df(self): assertDataFrameEqual(df1, df2, checkRowOrder=True) def test_assert_error_pandas_pyspark_df(self): - import pyspark.pandas as ps import pandas as pd df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) @@ -742,8 +808,6 @@ def test_assert_error_pandas_pyspark_df(self): ) def test_assert_error_non_pyspark_df(self): - import pyspark.pandas as ps - dict1 = {"a": 1, "b": 2} dict2 = {"a": 1, "b": 2} diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index db615d339b5ae..092fa43b1d2e7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -2402,7 +2402,7 @@ def __repr__(self) -> str: "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self)) ) else: - return "" % ", ".join("%r" % field for field in self) + return "" % ", ".join(repr(field) for field in self) class DateConverter: diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index d14a263f839c9..027a2646a4657 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -17,15 +17,17 @@ """ User-defined table function related classes and functions """ +import pickle from dataclasses import dataclass +from functools import wraps import inspect import sys import warnings -from typing import Any, Iterator, Type, TYPE_CHECKING, Optional, Union +from typing import Any, Iterable, Iterator, Type, TYPE_CHECKING, Optional, Union, Callable from py4j.java_gateway import JavaObject -from pyspark.errors import PySparkAttributeError, PySparkTypeError +from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError from pyspark.rdd import PythonEvalType from pyspark.sql.column import _to_java_column, _to_seq from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -104,11 +106,11 @@ def _create_py_udtf( from pyspark.sql import SparkSession session = SparkSession._instantiatedSession - arrow_enabled = ( - session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true" - if session is not None - else True - ) + arrow_enabled = False + if session is not None: + value = session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") + if 
isinstance(value, str) and value.lower() == "true": + arrow_enabled = True # Create a regular Python UDTF and check for invalid handler class. regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) @@ -143,6 +145,20 @@ def _vectorize_udtf(cls: Type) -> Type: """Vectorize a Python UDTF handler class.""" import pandas as pd + # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. + def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def evaluate(*a: Any) -> Any: + try: + return f(*a) + except Exception as e: + raise PySparkRuntimeError( + error_class="UDTF_EXEC_ERROR", + message_parameters={"method_name": f.__name__, "error": str(e)}, + ) + + return evaluate + class VectorizedUDTF: def __init__(self) -> None: self.func = cls() @@ -157,17 +173,26 @@ def analyze(*args: AnalyzeArgument) -> AnalyzeResult: def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]: if len(args) == 0: - yield pd.DataFrame(self.func.eval()) + yield pd.DataFrame(wrap_func(self.func.eval)()) else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. row_tuples = zip(*args) for row in row_tuples: - yield pd.DataFrame(self.func.eval(*row)) + res = wrap_func(self.func.eval)(*row) + if res is not None and not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={ + "type": type(res).__name__, + }, + ) + yield pd.DataFrame(res) + + if hasattr(cls, "terminate"): - def terminate(self) -> Iterator[pd.DataFrame]: - if hasattr(self.func, "terminate"): - yield pd.DataFrame(self.func.terminate()) + def terminate(self) -> Iterator[pd.DataFrame]: + yield pd.DataFrame(wrap_func(self.func.terminate)()) vectorized_udtf = VectorizedUDTF vectorized_udtf.__name__ = cls.__name__ @@ -279,7 +304,29 @@ def _create_judtf(self, func: Type) -> JavaObject: spark = SparkSession._getActiveSessionOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, func) + try: + wrapped_func = _wrap_function(sc, func) + except pickle.PicklingError as e: + if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e): + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "it appears that you are attempting to reference SparkSession " + "inside a UDTF. SparkSession can only be used on the driver, " + "not in code that runs on workers. Please remove the reference " + "and try again.", + }, + ) from None + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "Please check the stack trace and make sure the " + "function is serializable.", + }, + ) + assert sc._jvm is not None if self.returnType is None: judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction( diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8b520ed653f8c..d4f56fe822f3e 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import inspect import functools import os from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type @@ -258,6 +259,23 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return cast(FuncT, wrapped) +def try_remote_session_classmethod(f: FuncT) -> FuncT: + """Mark API supported from Spark Connect.""" + + @functools.wraps(f) + def wrapped(*args: Any, **kwargs: Any) -> Any: + + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + from pyspark.sql.connect.session import SparkSession # type: ignore[misc] + + assert inspect.isclass(args[0]) + return getattr(SparkSession, f.__name__)(*args[1:], **kwargs) + else: + return f(*args, **kwargs) + + return cast(FuncT, wrapped) + + def pyspark_column_op( func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None ) -> Union["SeriesOrIndex", None]: diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index 44dcd8c892c8e..9ffa03541e695 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -40,6 +40,7 @@ pickleSer, send_accumulator_updates, setup_broadcasts, + setup_memory_limits, setup_spark_files, utf8_deserializer, ) @@ -96,6 +97,10 @@ def main(infile: IO, outfile: IO) -> None: """ try: check_python_version(infile) + + memory_limit_mb = int(os.environ.get("PYSPARK_UDTF_ANALYZER_MEMORY_MB", "-1")) + setup_memory_limits(memory_limit_mb) + setup_spark_files(infile) setup_broadcasts(infile) diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index 5899925352144..39196873482b1 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -124,10 +124,10 @@ def _assert_pandas_equal( raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": left, - "left_dtype": left.dtype, - "right": right, - "right_dtype": right.dtype, + "left": left.to_string(), + "left_dtype": str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), }, ) elif isinstance(left, pd.Index) and isinstance(right, pd.Index): @@ -143,9 +143,9 @@ def _assert_pandas_equal( error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) else: @@ -228,10 +228,10 @@ def _assert_pandas_almost_equal( raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": left, - "left_dtype": left.dtype, - "right": right, - "right_dtype": right.dtype, + "left": left.to_string(), + "left_dtype": str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), }, ) for lnull, rnull in zip(left.isnull(), right.isnull()): @@ -239,10 +239,10 @@ def _assert_pandas_almost_equal( raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": left, - "left_dtype": left.dtype, - "right": right, - "right_dtype": right.dtype, + "left": left.to_string(), + "left_dtype": str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), }, ) for lval, rval in zip(left.dropna(), right.dropna()): @@ -253,10 +253,10 @@ def _assert_pandas_almost_equal( raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_SERIES", message_parameters={ - "left": left, - "left_dtype": left.dtype, - "right": right, - "right_dtype": right.dtype, + "left": left.to_string(), + "left_dtype": 
str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), }, ) elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex): @@ -265,9 +265,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_MULTIINDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) for lval, rval in zip(left, right): @@ -279,9 +279,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_MULTIINDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) elif isinstance(left, pd.Index) and isinstance(right, pd.Index): @@ -290,9 +290,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) for lnull, rnull in zip(left.isnull(), right.isnull()): @@ -301,9 +301,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) for lval, rval in zip(left.dropna(), right.dropna()): @@ -315,9 +315,9 @@ def _assert_pandas_almost_equal( error_class="DIFFERENT_PANDAS_INDEX", message_parameters={ "left": left, - "left_dtype": left.dtype, + "left_dtype": str(left.dtype), "right": right, - "right_dtype": right.dtype, + "right_dtype": str(right.dtype), }, ) else: diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 2a23476112fee..8e02803efe5cb 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -292,6 +292,7 @@ def assertSchemaEqual(actual: StructType, expected: StructType): >>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)]) >>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)]) >>> assertSchemaEqual(s1, s2) # pass, schemas are identical + >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "number"]) >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], schema=["id", "amount"]) >>> assertSchemaEqual(df1.schema, df2.schema) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -414,16 +415,20 @@ def assertDataFrameEqual( >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"]) >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"]) >>> assertDataFrameEqual(df1, df2) # pass, DataFrames are identical + >>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"]) >>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"]) >>> assertDataFrameEqual(df1, df2, rtol=1e-1) # pass, DataFrames are approx equal by rtol + >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "amount"]) >>> list_of_rows = [Row(1, 1000), Row(2, 3000)] >>> assertDataFrameEqual(df1, list_of_rows) # pass, actual and expected data are equal + >>> import pyspark.pandas as ps >>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) >>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) >>> assertDataFrameEqual(df1, df2) # pass, pandas-on-Spark 
DataFrames are equal + >>> df1 = spark.createDataFrame( ... data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"]) >>> df2 = spark.createDataFrame( @@ -432,26 +437,39 @@ def assertDataFrameEqual( Traceback (most recent call last): ... PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % ) - --- actual - +++ expected - - Row(id='1', amount=1000.0) - ? ^ - + Row(id='1', amount=1001.0) - ? ^ - - Row(id='3', amount=2000.0) - ? ^ - + Row(id='3', amount=2003.0) - ? ^ - + *** actual *** + ! Row(id='1', amount=1000.0) + Row(id='2', amount=3000.0) + ! Row(id='3', amount=2000.0) + *** expected *** + ! Row(id='1', amount=1001.0) + Row(id='2', amount=3000.0) + ! Row(id='3', amount=2003.0) """ - if actual is None and expected is None: - return True - elif actual is None or expected is None: - return False - import pyspark.pandas as ps from pyspark.testing.pandasutils import assertPandasOnSparkEqual + if actual is None and expected is None: + return True + elif actual is None: + raise PySparkAssertionError( + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "actual", + "actual_type": None, + }, + ) + elif expected is None: + raise PySparkAssertionError( + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "arg_name": "expected", + "actual_type": None, + }, + ) + try: # If Spark Connect dependencies are available, allow Spark Connect DataFrame from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 20e856c9addc3..6f27400387e72 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -21,18 +21,11 @@ import os import sys import time -from inspect import currentframe, getframeinfo, getfullargspec +from inspect import getfullargspec import json -from typing import Iterator - -# 'resource' is a Unix specific module. -has_resource_module = True -try: - import resource -except ImportError: - has_resource_module = False +from typing import Iterable, Iterator + import traceback -import warnings import faulthandler from pyspark.accumulators import _accumulatorRegistry @@ -70,6 +63,7 @@ pickleSer, send_accumulator_updates, setup_broadcasts, + setup_memory_limits, setup_spark_files, utf8_deserializer, ) @@ -591,6 +585,7 @@ def read_udtf(pickleSer, infile, eval_type): def wrap_arrow_udtf(f, return_type): arrow_return_type = to_arrow_type(return_type) + return_type_size = len(return_type) def verify_result(result): import pandas as pd @@ -599,7 +594,7 @@ def verify_result(result): raise PySparkTypeError( error_class="INVALID_ARROW_UDTF_RETURN_TYPE", message_parameters={ - "type_name": type(result).__name_, + "type_name": type(result).__name__, "value": str(result), }, ) @@ -609,11 +604,11 @@ def verify_result(result): # result dataframe may contain an empty row. For example, when a UDTF is # defined as follows: def eval(self): yield tuple(). 
if len(result) > 0 or len(result.columns) > 0: - if len(result.columns) != len(return_type): + if len(result.columns) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ - "expected": str(len(return_type)), + "expected": str(return_type_size), "actual": str(len(result.columns)), }, ) @@ -641,13 +636,7 @@ def mapper(_, it): yield from eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: - try: - yield from terminate() - except BaseException as e: - raise PySparkRuntimeError( - error_class="UDTF_EXEC_ERROR", - message_parameters={"method_name": "terminate", "error": str(e)}, - ) + yield from terminate() return mapper, None, ser, ser @@ -656,32 +645,52 @@ def mapper(_, it): def wrap_udtf(f, return_type): assert return_type.needConversion() toInternal = return_type.toInternal + return_type_size = len(return_type) def verify_and_convert_result(result): - # TODO(SPARK-44005): support returning non-tuple values - if result is not None and hasattr(result, "__len__"): - if len(result) != len(return_type): + if result is not None: + if hasattr(result, "__len__") and len(result) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ - "expected": str(len(return_type)), + "expected": str(return_type_size), "actual": str(len(result)), }, ) + + if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")): + raise PySparkRuntimeError( + error_class="UDTF_INVALID_OUTPUT_ROW_TYPE", + message_parameters={"type": type(result).__name__}, + ) + return toInternal(result) # Evaluate the function and return a tuple back to the executor. def evaluate(*a) -> tuple: - res = f(*a) + try: + res = f(*a) + except Exception as e: + raise PySparkRuntimeError( + error_class="UDTF_EXEC_ERROR", + message_parameters={"method_name": f.__name__, "error": str(e)}, + ) + if res is None: # If the function returns None or does not have an explicit return statement, # an empty tuple is returned to the executor. # This is because directly constructing tuple(None) results in an exception. return tuple() - else: - # If the function returns a result, we map it to the internal representation and - # returns the results as a tuple. - return tuple(map(verify_and_convert_result, res)) + + if not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={"type": type(res).__name__}, + ) + + # If the function returns a result, we map it to the internal representation and + # returns the results as a tuple. 
+ return tuple(map(verify_and_convert_result, res)) return evaluate @@ -699,13 +708,7 @@ def mapper(_, it): yield eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: - try: - yield terminate() - except BaseException as e: - raise PySparkRuntimeError( - error_class="UDTF_EXEC_ERROR", - message_parameters={"method_name": "terminate", "error": str(e)}, - ) + yield terminate() return mapper, None, ser, ser @@ -995,38 +998,8 @@ def main(infile, outfile): boundPort = read_int(infile) secret = UTF8Deserializer().loads(infile) - # set up memory limits memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1")) - if memory_limit_mb > 0 and has_resource_module: - total_memory = resource.RLIMIT_AS - try: - (soft_limit, hard_limit) = resource.getrlimit(total_memory) - msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) - print(msg, file=sys.stderr) - - # convert to bytes - new_limit = memory_limit_mb * 1024 * 1024 - - if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: - msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) - print(msg, file=sys.stderr) - resource.setrlimit(total_memory, (new_limit, new_limit)) - - except (resource.error, OSError, ValueError) as e: - # not all systems support resource limits, so warn instead of failing - lineno = ( - getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0 - ) - if "__file__" in globals(): - print( - warnings.formatwarning( - "Failed to set memory limit: {0}".format(e), - ResourceWarning, - __file__, - lineno, - ), - file=sys.stderr, - ) + setup_memory_limits(memory_limit_mb) # initialize global state taskContext = None diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py index eab0daf8f592b..9f6d46c6211d5 100644 --- a/python/pyspark/worker_util.py +++ b/python/pyspark/worker_util.py @@ -19,9 +19,18 @@ Util functions for workers. """ import importlib +from inspect import currentframe, getframeinfo import os import sys from typing import Any, IO +import warnings + +# 'resource' is a Unix specific module. +has_resource_module = True +try: + import resource +except ImportError: + has_resource_module = False from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -71,6 +80,44 @@ def check_python_version(infile: IO) -> None: ) +def setup_memory_limits(memory_limit_mb: int) -> None: + """ + Sets up the memory limits. + + If memory_limit_mb > 0 and `resource` module is available, sets the memory limit. + Windows does not support resource limiting and actual resource is not limited on MacOS. 
+ """ + if memory_limit_mb > 0 and has_resource_module: + total_memory = resource.RLIMIT_AS + try: + (soft_limit, hard_limit) = resource.getrlimit(total_memory) + msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) + print(msg, file=sys.stderr) + + # convert to bytes + new_limit = memory_limit_mb * 1024 * 1024 + + if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: + msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) + print(msg, file=sys.stderr) + resource.setrlimit(total_memory, (new_limit, new_limit)) + + except (resource.error, OSError, ValueError) as e: + # not all systems support resource limits, so warn instead of failing + curent = currentframe() + lineno = getframeinfo(curent).lineno + 1 if curent is not None else 0 + if "__file__" in globals(): + print( + warnings.formatwarning( + "Failed to set memory limit: {0}".format(e), + ResourceWarning, + __file__, + lineno, + ), + file=sys.stderr, + ) + + def setup_spark_files(infile: IO) -> None: """ Set up Spark files, archives, and pyfiles. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala index 4809222650d82..6953ed789f797 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -86,12 +86,20 @@ class ExecutorPodsWatchSnapshotSource( } override def onClose(e: WatcherException): Unit = { - logWarning("Kubernetes client has been closed (this is expected if the application is" + - " shutting down.)", e) + if (SparkContext.getActive.map(_.isStopped).getOrElse(true)) { + logInfo("Kubernetes client has been closed.") + } else { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } } override def onClose(): Unit = { - logWarning("Kubernetes client has been closed.") + if (SparkContext.getActive.map(_.isStopped).getOrElse(true)) { + logInfo("Kubernetes client has been closed.") + } else { + logWarning("Kubernetes client has been closed.") + } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/api/src/main/java/org/apache/spark/sql/RowFactory.java similarity index 100% rename from sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java rename to sql/api/src/main/java/org/apache/spark/sql/RowFactory.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/Trigger.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java rename to sql/api/src/main/java/org/apache/spark/sql/streaming/Trigger.java diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index f352d28a7b501..3d536b735db59 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.catalyst import java.beans.{Introspector, PropertyDescriptor} import java.lang.reflect.{ParameterizedType, 
Type, TypeVariable} -import java.util.{ArrayDeque, List => JList, Map => JMap} +import java.util.{List => JList, Map => JMap} import javax.annotation.Nonnull -import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import org.apache.commons.lang3.reflect.{TypeUtils => JavaTypeUtils} + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.errors.ExecutionErrors @@ -57,7 +59,8 @@ object JavaTypeInference { encoderFor(beanType, Set.empty).asInstanceOf[AgnosticEncoder[T]] } - private def encoderFor(t: Type, seenTypeSet: Set[Class[_]]): AgnosticEncoder[_] = t match { + private def encoderFor(t: Type, seenTypeSet: Set[Class[_]], + typeVariables: Map[TypeVariable[_], Type] = Map.empty): AgnosticEncoder[_] = t match { case c: Class[_] if c == java.lang.Boolean.TYPE => PrimitiveBooleanEncoder case c: Class[_] if c == java.lang.Byte.TYPE => PrimitiveByteEncoder @@ -101,18 +104,24 @@ object JavaTypeInference { UDTEncoder(udt, udt.getClass) case c: Class[_] if c.isArray => - val elementEncoder = encoderFor(c.getComponentType, seenTypeSet) + val elementEncoder = encoderFor(c.getComponentType, seenTypeSet, typeVariables) ArrayEncoder(elementEncoder, elementEncoder.nullable) - case ImplementsList(c, Array(elementCls)) => - val element = encoderFor(elementCls, seenTypeSet) + case c: Class[_] if classOf[JList[_]].isAssignableFrom(c) => + val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables) IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false) - case ImplementsMap(c, Array(keyCls, valueCls)) => - val keyEncoder = encoderFor(keyCls, seenTypeSet) - val valueEncoder = encoderFor(valueCls, seenTypeSet) + case c: Class[_] if classOf[JMap[_, _]].isAssignableFrom(c) => + val keyEncoder = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables) + val valueEncoder = encoderFor(c.getTypeParameters.array(1), seenTypeSet, typeVariables) MapEncoder(ClassTag(c), keyEncoder, valueEncoder, valueEncoder.nullable) + case tv: TypeVariable[_] => + encoderFor(typeVariables(tv), seenTypeSet, typeVariables) + + case pt: ParameterizedType => + encoderFor(pt.getRawType, seenTypeSet, JavaTypeUtils.getTypeArguments(pt).asScala.toMap) + case c: Class[_] => if (seenTypeSet.contains(c)) { throw ExecutionErrors.cannotHaveCircularReferencesInBeanClassError(c) @@ -124,7 +133,7 @@ object JavaTypeInference { // Note that the fields are ordered by name. val fields = properties.map { property => val readMethod = property.getReadMethod - val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c) + val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, typeVariables) // The existence of `javax.annotation.Nonnull`, means this field is not nullable. 
val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) EncoderField( @@ -147,59 +156,4 @@ object JavaTypeInference { .filterNot(_.getName == "declaringClass") .filter(_.getReadMethod != null) } - - private class ImplementsGenericInterface(interface: Class[_]) { - assert(interface.isInterface) - assert(interface.getTypeParameters.nonEmpty) - - def unapply(t: Type): Option[(Class[_], Array[Type])] = implementsInterface(t).map { cls => - cls -> findTypeArgumentsForInterface(t) - } - - @tailrec - private def implementsInterface(t: Type): Option[Class[_]] = t match { - case pt: ParameterizedType => implementsInterface(pt.getRawType) - case c: Class[_] if interface.isAssignableFrom(c) => Option(c) - case _ => None - } - - private def findTypeArgumentsForInterface(t: Type): Array[Type] = { - val queue = new ArrayDeque[(Type, Map[Any, Type])] - queue.add(t -> Map.empty) - while (!queue.isEmpty) { - queue.poll() match { - case (pt: ParameterizedType, bindings) => - // translate mappings... - val mappedTypeArguments = pt.getActualTypeArguments.map { - case v: TypeVariable[_] => bindings(v.getName) - case v => v - } - if (pt.getRawType == interface) { - return mappedTypeArguments - } else { - val mappedTypeArgumentMap = mappedTypeArguments - .zipWithIndex.map(_.swap) - .toMap[Any, Type] - queue.add(pt.getRawType -> mappedTypeArgumentMap) - } - case (c: Class[_], indexedBindings) => - val namedBindings = c.getTypeParameters.zipWithIndex.map { - case (parameter, index) => - parameter.getName -> indexedBindings(index) - }.toMap[Any, Type] - val superClass = c.getGenericSuperclass - if (superClass != null) { - queue.add(superClass -> namedBindings) - } - c.getGenericInterfaces.foreach { iface => - queue.add(iface -> namedBindings) - } - } - } - throw ExecutionErrors.unreachableError() - } - } - - private object ImplementsList extends ImplementsGenericInterface(classOf[JList[_]]) - private object ImplementsMap extends ImplementsGenericInterface(classOf[JMap[_, _]]) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index 7a34a386cd889..5e52e283338d3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -192,15 +192,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { decimalPrecision: Int, decimalScale: Int, context: SQLQueryContext = null): ArithmeticException = { - new SparkArithmeticException( - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", - messageParameters = Map( - "value" -> value.toPlainString, - "precision" -> decimalPrecision.toString, - "scale" -> decimalScale.toString, - "config" -> toSQLConf("spark.sql.ansi.enabled")), - context = getQueryContext(context), - summary = getSummary(context)) + numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } def cannotChangeDecimalPrecisionError( @@ -208,6 +200,14 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { decimalPrecision: Int, decimalScale: Int, context: SQLQueryContext = null): ArithmeticException = { + numericValueOutOfRange(value, decimalPrecision, decimalScale, context) + } + + private def numericValueOutOfRange( + value: Decimal, + decimalPrecision: Int, + decimalScale: Int, + context: SQLQueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala similarity index 96% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala rename to sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index ad19ad1780549..37c5b314978bb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.streaming.Trigger import org.apache.spark.unsafe.types.UTF8String private object Triggers { - // This is a copy of the same class in sql/core/...execution/streaming/Triggers.scala - def validate(intervalMs: Long): Unit = { require(intervalMs >= 0, "the interval of trigger should not be negative") } @@ -87,8 +85,8 @@ object ProcessingTimeTrigger { } /** - * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at the - * specified interval. + * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. */ case class ContinuousTrigger(intervalMs: Long) extends Trigger { Triggers.validate(intervalMs) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 4e7ac996d31e1..3677927b9a555 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -137,6 +137,8 @@ object Metadata { jObj.obj.foreach { case (key, JInt(value)) => builder.putLong(key, value.toLong) + case (key, JLong(value)) => + builder.putLong(key, value.toLong) case (key, JDouble(value)) => builder.putDouble(key, value) case (key, JBool(value)) => @@ -153,6 +155,8 @@ object Metadata { value.head match { case _: JInt => builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray) + case _: JLong => + builder.putLongArray(key, value.asInstanceOf[List[JLong]].map(_.num.toLong).toArray) case _: JDouble => builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray) case _: JBool => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index a419804488654..9b95f74db3a49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -178,6 +178,13 @@ object Encoders { */ def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + /** + * Creates a [[Row]] encoder for schema `schema`. + * + * @since 3.5.0 + */ + def row(schema: StructType): Encoder[Row] = ExpressionEncoder(schema) + /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. 
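For context, here is a minimal usage sketch of the new `Encoders.row` helper introduced above. It is illustrative only and not part of the patch: the local session settings, the schema and the sample rows are assumptions made for the example.

import org.apache.spark.sql.{Encoders, Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object RowEncoderExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("row-encoder-example").getOrCreate()
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType)))
    // Encoders.row binds an Encoder[Row] to an explicit schema, so plain Rows can be used
    // with typed APIs (createDataset, map, ...) without defining a case class.
    val rowEncoder = Encoders.row(schema)
    val ds = spark.createDataset(Seq(Row(1, "a"), Row(2, "b")))(rowEncoder)
    ds.map(r => Row(r.getInt(0), r.getString(1).toUpperCase))(rowEncoder).show()
    spark.stop()
  }
}

Passing the encoder explicitly, as above, keeps the example self-contained; when many Row-typed operations share one schema it can also be declared as an implicit value.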
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1de745baa0544..6c1d774a1b5fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -759,7 +759,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p - case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + case p @ Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => if (!RowOrdering.isOrderable(pivotColumn.dataType)) { throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) } @@ -823,7 +823,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))() } } - Project(groupByExprsAttr ++ pivotOutputs, secondAgg) + val newProject = Project(groupByExprsAttr ++ pivotOutputs, secondAgg) + newProject.copyTagsFrom(p) + newProject } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(e: Expression) = { @@ -857,7 +859,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Alias(filteredAggregate, outputName(value, aggregate))() } } - Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) + val newAggregate = Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) + newAggregate.copyTagsFrom(p) + newAggregate } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala index cebc1e25f9213..ead323ce9857b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors /** * A virtual rule to resolve [[UnresolvedAttribute]] in [[UpdateTable]]. It's only used by the real * rule `ResolveReferences`. The column resolution order for [[UpdateTable]] is: - * 1. Resolves the column to `AttributeReference`` with the output of the child plan. This + * 1. Resolves the column to `AttributeReference` with the output of the child plan. This * includes metadata columns as well. * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g. * `SELECT col, current_date FROM t`. 
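The `copyTagsFrom` calls added to the Pivot rewrite above follow a general pattern in the analyzer: a rule that returns a brand-new plan node should copy the old node's tags, otherwise metadata attached by earlier phases is silently dropped. The sketch below only illustrates that mechanism; the tag name and the toy plan are assumptions for the example, not code from the patch.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.types.IntegerType

object TagCopySketch extends App {
  val marker = TreeNodeTag[String]("example.marker")   // hypothetical tag, for illustration only
  val a = AttributeReference("a", IntegerType)()
  val relation = LocalRelation(a)

  val original = Project(Seq(a), relation)
  original.setTagValue(marker, "attached by an earlier phase")

  val rewritten = Project(Seq(a), relation)             // a freshly built replacement node
  assert(rewritten.getTagValue(marker).isEmpty)         // tags do not come along automatically...

  rewritten.copyTagsFrom(original)                      // ...until they are copied explicitly
  assert(rewritten.getTagValue(marker).contains("attached by an earlier phase"))
}

The same reasoning applies to the `Dataset#union` flattening later in this patch, which likewise calls `copyTagsFrom` on the combined `Union`.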
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 388edb9024ca1..91c17a475cd94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -231,8 +231,8 @@ class JacksonParser( Float.PositiveInfinity case "-INF" | "-Infinity" if options.allowNonNumericNumbers => Float.NegativeInfinity - case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError( - parser, VALUE_STRING, FloatType) + case _ => throw StringAsDataTypeException(parser.getCurrentName, parser.getText, + FloatType) } } @@ -250,8 +250,8 @@ class JacksonParser( Double.PositiveInfinity case "-INF" | "-Infinity" if options.allowNonNumericNumbers => Double.NegativeInfinity - case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError( - parser, VALUE_STRING, DoubleType) + case _ => throw StringAsDataTypeException(parser.getCurrentName, parser.getText, + DoubleType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9fc664bb1c26d..f83cd36f0a82b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -157,7 +157,7 @@ abstract class Optimizer(catalogManager: CatalogManager) // since the other rules might make two separate Unions operators adjacent. Batch("Inline CTE", Once, InlineCTE()) :: - Batch("Union", Once, + Batch("Union", fixedPoint, RemoveNoopOperators, CombineUnions, RemoveNoopUnion) :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8cb560199c069..7b44539929c84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -122,8 +122,6 @@ object ConstantPropagation extends Rule[LogicalPlan] { } } - type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] - /** * Traverse a condition as a tree and replace attributes with constant values. * - On matching [[And]], recursively traverse each children and get propagated mappings. @@ -140,23 +138,23 @@ object ConstantPropagation extends Rule[LogicalPlan] { * resulted false * @return A tuple including: * 1. Option[Expression]: optional changed condition after traversal - * 2. EqualityPredicates: propagated mapping of attribute => constant + * 2. 
AttributeMap: propagated mapping of attribute => constant */ private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean) - : (Option[Expression], EqualityPredicates) = + : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) = condition match { case e @ EqualTo(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) + (None, AttributeMap(Map(left -> (right, e)))) case e @ EqualTo(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) + (None, AttributeMap(Map(right -> (left, e)))) case e @ EqualNullSafe(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) + (None, AttributeMap(Map(left -> (right, e)))) case e @ EqualNullSafe(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) + (None, AttributeMap(Map(right -> (left, e)))) case a: And => val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false, nullIsFalse) @@ -183,12 +181,12 @@ object ConstantPropagation extends Rule[LogicalPlan] { } else { None } - (newSelf, Seq.empty) + (newSelf, AttributeMap.empty) case n: Not => // Ignore the EqualityPredicates from children since they are only propagated through And. val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false) - (newChild.map(Not), Seq.empty) - case _ => (None, Seq.empty) + (newChild.map(Not), AttributeMap.empty) + case _ => (None, AttributeMap.empty) } // We need to take into account if an attribute is nullable and the context of the conjunctive @@ -199,16 +197,15 @@ object ConstantPropagation extends Rule[LogicalPlan] { private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) = !ar.nullable || nullIsFalse - private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) - : Expression = { - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet - def replaceConstants0(expression: Expression) = expression transform { - case a: AttributeReference => constantsMap.getOrElse(a, a) - } + private def replaceConstants( + condition: Expression, + equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = { + val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit }) + val predicates = equalityPredicates.values.map(_._2).toSet condition transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) + case b: BinaryComparison if !predicates.contains(b) => b transform { + case a: AttributeReference => constantsMap.getOrElse(a, a) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 4a2b9eae98100..1088655f60cd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -104,7 +104,7 @@ object NamedParametersSupport { val positionalParametersSet = allParameterNames.take(positionalArgs.size).toSet val namedParametersSet = collection.mutable.Set[String]() - for (arg <- namedArgs) 
{ + namedArgs.zipWithIndex.foreach { case (arg, index) => arg match { case namedArg: NamedArgumentExpression => val parameterName = namedArg.key @@ -122,7 +122,8 @@ object NamedParametersSupport { } namedParametersSet.add(namedArg.key) case _ => - throw QueryCompilationErrors.unexpectedPositionalArgument(functionName) + throw QueryCompilationErrors.unexpectedPositionalArgument( + functionName, namedArgs(index - 1).asInstanceOf[NamedArgumentExpression].key) } } @@ -141,15 +142,16 @@ object NamedParametersSupport { }.toMap // We rearrange named arguments to match their positional order. - val rearrangedNamedArgs: Seq[Expression] = namedParameters.map { param => - namedArgMap.getOrElse( - param.name, - if (param.default.isEmpty) { - throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name) - } else { - param.default.get - } - ) + val rearrangedNamedArgs: Seq[Expression] = namedParameters.zipWithIndex.map { + case (param, index) => + namedArgMap.getOrElse( + param.name, + if (param.default.isEmpty) { + throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name, index) + } else { + param.default.get + } + ) } val rearrangedArgs = positionalArgs ++ rearrangedNamedArgs assert(rearrangedArgs.size == parameters.size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index bd8ba54ddd736..456005768bd42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -313,7 +313,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in * ascending order, after evaluated by the transforms in `expressions`, for each input partition. * In addition, its length must be the same as the number of input partitions (and thus is a 1-1 - * mapping), and each row in `partitionValues` must be unique. + * mapping). The `partitionValues` may contain duplicated partition values. * * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValues` is * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. 
All rows @@ -355,6 +355,13 @@ case class KeyGroupedPartitioning( override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = KeyGroupedShuffleSpec(this, distribution) + + lazy val uniquePartitionValues: Seq[InternalRow] = { + partitionValues + .map(InternalRowComparableWrapper(_, expressions)) + .distinct + .map(_.row) + } } object KeyGroupedPartitioning { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala index e1223a71f746b..7bf01fba8cd9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** @@ -56,3 +57,11 @@ case class BadRecordException( * Exception thrown when the underlying parser parses a JSON array as a struct. */ case class JsonArraysAsStructsException() extends RuntimeException() + +/** + * Exception thrown when the underlying parser cannot parse a String as a datatype. + */ +case class StringAsDataTypeException( + fieldName: String, + fieldValue: String, + dataType: DataType) extends RuntimeException() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index 2a9370b8c91ce..0a5764e21e14e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -75,6 +75,9 @@ class FailureSafeParser[IN]( // SPARK-42298 we recreate the exception here to make sure the error message // have the record content.
throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError(e.record().toString) + case StringAsDataTypeException(fieldName, fieldValue, dataType) => + throw QueryExecutionErrors.cannotParseStringAsDataTypeError(e.record().toString, + fieldName, fieldValue, dataType) case _ => throw QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError( toResultRow(e.partialResults().headOption, e.record).toString, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 41de0c76b3b00..1e4f779e565af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -90,12 +90,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat } def requiredParameterNotFound( - functionName: String, parameterName: String) : Throwable = { + functionName: String, parameterName: String, index: Int) : Throwable = { new AnalysisException( errorClass = "REQUIRED_PARAMETER_NOT_FOUND", messageParameters = Map( "functionName" -> toSQLId(functionName), - "parameterName" -> toSQLId(parameterName)) + "parameterName" -> toSQLId(parameterName), + "index" -> index.toString) ) } @@ -115,10 +116,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } - def unexpectedPositionalArgument(functionName: String): Throwable = { + def unexpectedPositionalArgument( + functionName: String, + precedingNamedArgument: String): Throwable = { new AnalysisException( errorClass = "UNEXPECTED_POSITIONAL_ARGUMENT", - messageParameters = Map("functionName" -> toSQLId(functionName)) + messageParameters = Map( + "functionName" -> toSQLId(functionName), + "parameterName" -> toSQLId(precedingNamedArgument)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala index db256fbee8785..26600117a0c54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.errors import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} -import org.apache.spark.sql.catalyst.util.{toPrettySQL, QuotingUtils} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.{DataType, DoubleType, FloatType} /** @@ -55,10 +55,6 @@ private[sql] trait QueryErrorsBase extends DataTypeErrorsBase { quoteByDefault(toPrettySQL(e)) } - def toSQLSchema(schema: String): String = { - QuotingUtils.toSQLSchema(schema) - } - // Converts an error class parameter to its SQL representation def toSQLValue(v: Any, t: DataType): String = Literal.create(v, t) match { case Literal(null, _) => "NULL" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 3622ffebb74d9..f3c5fb4bef3b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -32,7 +32,6 @@ import org.apache.spark._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.memory.SparkOutOfMemoryError import 
org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} @@ -183,10 +182,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map.empty) } - def dataTypeUnsupportedError(dataType: String, failure: String): Throwable = { - DataTypeErrors.dataTypeUnsupportedError(dataType, failure) - } - def failedExecuteUserDefinedFunctionError(functionName: String, inputTypes: String, outputType: String, e: Throwable): Throwable = { new SparkException( @@ -503,10 +498,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("op" -> op.toString(), "pos" -> pos)) } - def unsupportedRoundingMode(roundMode: BigDecimal.RoundingMode.Value): SparkException = { - DataTypeErrors.unsupportedRoundingMode(roundMode) - } - def resolveCannotHandleNestedSchema(plan: LogicalPlan): SparkRuntimeException = { new SparkRuntimeException( errorClass = "_LEGACY_ERROR_TEMP_2030", @@ -1214,52 +1205,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("o" -> o.toString())) } - def unscaledValueTooLargeForPrecisionError( - value: Decimal, - decimalPrecision: Int, - decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { - DataTypeErrors.unscaledValueTooLargeForPrecisionError( - value, decimalPrecision, decimalScale, context) - } - - def decimalPrecisionExceedsMaxPrecisionError( - precision: Int, maxPrecision: Int): SparkArithmeticException = { - DataTypeErrors.decimalPrecisionExceedsMaxPrecisionError(precision, maxPrecision) - } - - def outOfDecimalTypeRangeError(str: UTF8String): SparkArithmeticException = { - new SparkArithmeticException( - errorClass = "NUMERIC_OUT_OF_SUPPORTED_RANGE", - messageParameters = Map( - "value" -> str.toString), - context = Array.empty, - summary = "") - } - - def unsupportedArrayTypeError(clazz: Class[_]): SparkRuntimeException = { - DataTypeErrors.unsupportedJavaTypeError(clazz) - } - - def unsupportedJavaTypeError(clazz: Class[_]): SparkRuntimeException = { - DataTypeErrors.unsupportedJavaTypeError(clazz) - } - - def failedParsingStructTypeError(raw: String): SparkRuntimeException = { - new SparkRuntimeException( - errorClass = "FAILED_PARSE_STRUCT_TYPE", - messageParameters = Map("raw" -> toSQLValue(raw, StringType))) - } - - def cannotMergeDecimalTypesWithIncompatibleScaleError( - leftScale: Int, rightScale: Int): Throwable = { - DataTypeErrors.cannotMergeDecimalTypesWithIncompatibleScaleError(leftScale, rightScale) - } - - def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { - DataTypeErrors.cannotMergeIncompatibleDataTypesError(left, right) - } - def exceedMapSizeLimitError(size: Int): SparkRuntimeException = { new SparkRuntimeException( errorClass = "_LEGACY_ERROR_TEMP_2126", @@ -1310,15 +1255,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "failFastMode" -> FailFastMode.name)) } - def cannotParseStringAsDataTypeError(parser: JsonParser, token: JsonToken, dataType: DataType) - : SparkRuntimeException = { + def cannotParseStringAsDataTypeError( + recordStr: String, + fieldName: String, + fieldValue: String, + dataType: DataType): SparkRuntimeException = { new SparkRuntimeException( - errorClass = 
"_LEGACY_ERROR_TEMP_2133", + errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_STRING_AS_DATATYPE", messageParameters = Map( - "fieldName" -> parser.getCurrentName, - "fieldValue" -> parser.getText, - "token" -> token.toString(), - "dataType" -> dataType.toString())) + "badRecord" -> recordStr, + "failFastMode" -> FailFastMode.name, + "fieldName" -> toSQLId(fieldName), + "fieldValue" -> toSQLValue(fieldValue, StringType), + "inputType" -> StringType.toString, + "targetType" -> dataType.toString)) } def emptyJsonFieldValueError(dataType: DataType): SparkRuntimeException = { @@ -1344,13 +1294,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map.empty) } - def attributesForTypeUnsupportedError(schema: Schema): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2142", - messageParameters = Map( - "schema" -> schema.toString())) - } - def paramExceedOneCharError(paramName: String): SparkRuntimeException = { new SparkRuntimeException( errorClass = "_LEGACY_ERROR_TEMP_2145", @@ -1584,9 +1527,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def ruleIdNotFoundForRuleError(ruleName: String): Throwable = { new SparkException( - errorClass = "_LEGACY_ERROR_TEMP_2175", - messageParameters = Map( - "ruleName" -> ruleName), + errorClass = "RULE_ID_NOT_FOUND", + messageParameters = Map("ruleName" -> ruleName), cause = null) } @@ -2005,22 +1947,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def unsupportedOperationExceptionError(): SparkUnsupportedOperationException = { - DataTypeErrors.unsupportedOperationExceptionError() - } - - def nullLiteralsCannotBeCastedError(name: String): SparkUnsupportedOperationException = { - DataTypeErrors.nullLiteralsCannotBeCastedError(name) - } - - def notUserDefinedTypeError(name: String, userClass: String): Throwable = { - DataTypeErrors.notUserDefinedTypeError(name, userClass) - } - - def cannotLoadUserDefinedTypeError(name: String, userClass: String): Throwable = { - DataTypeErrors.cannotLoadUserDefinedTypeError(name, userClass) - } - def notPublicClassError(name: String): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( errorClass = "_LEGACY_ERROR_TEMP_2229", @@ -2034,14 +1960,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map.empty) } - def fieldIndexOnRowWithoutSchemaError(): SparkUnsupportedOperationException = { - DataTypeErrors.fieldIndexOnRowWithoutSchemaError() - } - - def valueIsNullError(index: Int): Throwable = { - DataTypeErrors.valueIsNullError(index) - } - def onlySupportDataSourcesProvidingFileFormatError(providingClass: String): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2233", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dfa2a0f251fea..bcf8ce2bc5407 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2942,7 +2942,19 @@ object SQLConf { .doc("Enable Arrow optimization for Python UDTFs.") .version("3.5.0") .booleanConf - .createWithDefault(true) + .createWithDefault(false) + + val PYTHON_TABLE_UDF_ANALYZER_MEMORY = + buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory") + .doc("The 
amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB " + + "unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be " + + "limited to this amount. If not set, Spark will not limit Python's " + + "memory use and it is up to the application to avoid exceeding the overhead memory space " + + "shared with other non-JVM processes.\nNote: Windows does not support resource limiting " + + "and actual resource is not limited on MacOS.") + .version("4.0.0") + .bytesConf(ByteUnit.MiB) + .createOptional val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME = buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName") @@ -5012,6 +5024,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def pysparkWorkerPythonExecutable: Option[String] = getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE) + def pythonUDTFAnalyzerMemory: Option[Long] = getConf(PYTHON_TABLE_UDF_ANALYZER_MEMORY) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java new file mode 100644 index 0000000000000..b84a3122cf84c --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst; + +class JavaBeanWithGenerics<A, T> { + private A attribute; + + private T value; + + public A getAttribute() { + return attribute; + } + + public void setAttribute(A attribute) { + this.attribute = attribute; + } + + public T getValue() { + return value; + } + + public void setValue(T value) { + this.value = value; + } +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala index 35f5bf739bfce..6439997609766 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala @@ -66,6 +66,7 @@ class LeafBean { @BeanProperty var period: java.time.Period = _ @BeanProperty var enum: java.time.Month = _ @BeanProperty val readOnlyString = "read-only" + @BeanProperty var genericNestedBean: JavaBeanWithGenerics[String, String] = _ var nonNullString: String = "value" @javax.annotation.Nonnull @@ -184,6 +185,9 @@ class JavaTypeInferenceSuite extends SparkFunSuite { encoderField("date", STRICT_DATE_ENCODER), encoderField("duration", DayTimeIntervalEncoder), encoderField("enum", JavaEnumEncoder(classTag[java.time.Month])), + encoderField("genericNestedBean", JavaBeanEncoder( + ClassTag(classOf[JavaBeanWithGenerics[String, String]]), + Seq(encoderField("attribute", StringEncoder), encoderField("value", StringEncoder)))), encoderField("instant", STRICT_INSTANT_ENCODER), encoderField("localDate", STRICT_LOCAL_DATE_ENCODER), encoderField("localDateTime", LocalDateTimeEncoder), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala index dd5cb5e7d03c8..99fed4d2ee5d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala @@ -40,6 +40,7 @@ case class DummyExpression( } object DummyExpressionBuilder extends ExpressionBuilder { + def defaultFunctionSignature: FunctionSignature = { FunctionSignature(Seq(InputParameter("k1"), InputParameter("k2"), @@ -49,11 +50,12 @@ object DummyExpressionBuilder extends ExpressionBuilder { override def functionSignature: Option[FunctionSignature] = Some(defaultFunctionSignature) + override def build(funcName: String, expressions: Seq[Expression]): Expression = DummyExpression(expressions(0), expressions(1), expressions(2), expressions(3)) } -class NamedArgumentFunctionSuite extends AnalysisTest { +class NamedParameterFunctionSuite extends AnalysisTest { final val k1Arg = Literal("v1") final val k2Arg = NamedArgumentExpression("k2", Literal("v2")) @@ -61,6 +63,7 @@ class NamedArgumentFunctionSuite extends AnalysisTest { final val k4Arg = NamedArgumentExpression("k4", Literal("v4")) final val namedK1Arg = NamedArgumentExpression("k1", Literal("v1-2")) final val args = Seq(k1Arg, k4Arg, k2Arg, k3Arg) + final val expectedSeq = Seq(Literal("v1"), Literal("v2"), Literal("v3"), Literal("v4")) final val signature = DummyExpressionBuilder.defaultFunctionSignature final val illegalSignature = FunctionSignature(Seq( @@ -115,8 +118,8 @@ class NamedArgumentFunctionSuite extends AnalysisTest { checkError( exception = parseRearrangeException(signature, Seq(k1Arg, k2Arg, 
k3Arg), "foo"), errorClass = "REQUIRED_PARAMETER_NOT_FOUND", - parameters = Map("functionName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4")) - ) + parameters = Map( + "functionName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4"), "index" -> "2")) } test("UNRECOGNIZED_PARAMETER_NAME") { @@ -134,7 +137,7 @@ class NamedArgumentFunctionSuite extends AnalysisTest { exception = parseRearrangeException(signature, Seq(k2Arg, k3Arg, k1Arg, k4Arg), "foo"), errorClass = "UNEXPECTED_POSITIONAL_ARGUMENT", - parameters = Map("functionName" -> toSQLId("foo")) + parameters = Map("functionName" -> toSQLId("foo"), "parameterName" -> toSQLId("k3")) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index f5f1455f94611..106af71a9d653 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -159,8 +159,9 @@ class ConstantPropagationSuite extends PlanTest { columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) val correctAnswer = testRelation - .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze + .select(columnA, columnB) + .where(Literal.FalseLiteral) + .select(columnA).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -186,4 +187,31 @@ class ConstantPropagationSuite extends PlanTest { .analyze comparePlans(Optimize.execute(query2), correctAnswer2) } + + test("SPARK-42500: ConstantPropagation supports more cases") { + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze), + testRelation.where(columnA === 1 && columnB > 3).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze), + testRelation.where(Literal.FalseLiteral).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze), + testRelation.where(Literal.FalseLiteral).analyze) + + comparePlans( + Optimize.execute( + testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze), + testRelation.where(columnA === 1 && columnB === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze), + testRelation.where(columnA === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze), + testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7b2259a6d9945..eda017937d918 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -42,11 +42,10 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} -import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import 
org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId @@ -1093,7 +1092,7 @@ class Dataset[T] private[sql]( Join( joined.left, joined.right, - UsingJoin(JoinType(joinType), usingColumns), + UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), None, JoinHint.NONE) } @@ -2241,6 +2240,51 @@ class Dataset[T] private[sql]( Offset(Literal(n), logicalPlan) } + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + private def combineUnions(plan: LogicalPlan): LogicalPlan = { + plan.transformDownWithPruning(_.containsPattern(TreePattern.UNION)) { + case Distinct(u: Union) => + Distinct(flattenUnion(u, isUnionDistinct = true)) + // Only handle distinct-like 'Deduplicate', where the keys == output + case Deduplicate(keys: Seq[Attribute], u: Union) if AttributeSet(keys) == u.outputSet => + Deduplicate(keys, flattenUnion(u, true)) + case u: Union => + flattenUnion(u, isUnionDistinct = false) + } + } + + private def flattenUnion(u: Union, isUnionDistinct: Boolean): Union = { + var changed = false + // We only need to look at the direct children of Union, as the nested adjacent Unions should + // have been combined already by previous `Dataset#union` transformations. + val newChildren = u.children.flatMap { + case Distinct(Union(children, byName, allowMissingCol)) + if isUnionDistinct && byName == u.byName && allowMissingCol == u.allowMissingCol => + changed = true + children + // Only handle distinct-like 'Deduplicate', where the keys == output + case Deduplicate(keys: Seq[Attribute], child @ Union(children, byName, allowMissingCol)) + if AttributeSet(keys) == child.outputSet && isUnionDistinct && byName == u.byName && + allowMissingCol == u.allowMissingCol => + changed = true + children + case Union(children, byName, allowMissingCol) + if !isUnionDistinct && byName == u.byName && allowMissingCol == u.allowMissingCol => + changed = true + children + case other => + Seq(other) + } + if (changed) { + val newUnion = Union(newChildren) + newUnion.copyTagsFrom(u) + newUnion + } else { + u + } + } + /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * @@ -2272,9 +2316,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def union(other: Dataset[T]): Dataset[T] = withSetOperator { - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)) + combineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -2366,9 +2408,7 @@ class Dataset[T] private[sql]( * @since 3.1.0 */ def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = withSetOperator { - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. 
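The combineUnions/flattenUnion helpers introduced above flatten directly nested Union nodes eagerly in Dataset itself instead of relying on the optimizer rule. A minimal standalone sketch of that flattening, using invented Plan/Leaf/Union stand-ins rather than Catalyst's classes:

    // Toy model of flattening directly nested unions into one n-ary union.
    sealed trait Plan
    final case class Leaf(name: String) extends Plan
    final case class Union(children: Seq[Plan]) extends Plan

    object FlattenUnionSketch {
      def flatten(u: Union): Union = {
        var changed = false
        // Only the direct children need to be inspected: deeper unions were
        // already flattened when they were built.
        val newChildren = u.children.flatMap {
          case Union(children) => changed = true; children
          case other           => Seq(other)
        }
        if (changed) Union(newChildren) else u
      }

      def main(args: Array[String]): Unit = {
        val nested = Union(Seq(Union(Seq(Leaf("a"), Leaf("b"))), Leaf("c")))
        println(flatten(nested)) // Union(List(Leaf(a), Leaf(b), Leaf(c)))
      }
    }

As in flattenUnion above, the original node is returned unchanged when no child could be inlined.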
- CombineUnions(Union(logicalPlan :: other.logicalPlan :: Nil, true, allowMissingColumns)) + combineUnions(Union(logicalPlan :: other.logicalPlan :: Nil, true, allowMissingColumns)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index a739fa40c71cb..e5a38967dc3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -249,12 +249,18 @@ trait FileSourceScanLike extends DataSourceScanExec { private def isDynamicPruningFilter(e: Expression): Boolean = e.exists(_.isInstanceOf[PlanExpression[_]]) + + // This field will be accessed during planning (e.g., `outputPartitioning` relies on it), and can + // only use static filters. @transient lazy val selectedPartitions: Array[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) val startTime = System.nanoTime() - val ret = - relation.location.listFiles( - partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + // The filters may contain subquery expressions which can't be evaluated during planning. + // Here we filter out subquery expressions and get the static data/partition filters, so that + // they can be used to do pruning at the planning phase. + val staticDataFilters = dataFilters.filterNot(isDynamicPruningFilter) + val staticPartitionFilters = partitionFilters.filterNot(isDynamicPruningFilter) + val ret = relation.location.listFiles(staticPartitionFilters, staticDataFilters) setFilesNumAndSizeMetric(ret, true) val timeTakenMs = NANOSECONDS.toMillis( (System.nanoTime() - startTime) + optimizerMetadataTimeNs) @@ -266,6 +272,7 @@ trait FileSourceScanLike extends DataSourceScanExec { // present. This is because such a filter relies on information that is only available at run // time (for instance the keys used in the other side of a join). 
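As context for the static/dynamic split used by selectedPartitions above and dynamicallySelectedPartitions below: a minimal, self-contained sketch (SimpleFilter and the sample data are invented for illustration, not Spark's API) of applying only statically evaluable filters at planning time and deferring subquery-dependent ones to execution:

    // A "dynamic" filter depends on information that only exists at execution
    // time (e.g. a subquery result), so it cannot prune at planning time.
    final case class SimpleFilter(name: String, isDynamic: Boolean, pred: Int => Boolean)

    object PruningSketch {
      def main(args: Array[String]): Unit = {
        val partitions = Seq(1, 2, 3, 4, 5)
        val filters = Seq(
          SimpleFilter(name = "p >= 2", isDynamic = false, pred = (p: Int) => p >= 2),
          SimpleFilter(name = "p = <subquery>", isDynamic = true, pred = (p: Int) => p == 4))

        // Planning phase: only the static filters are usable.
        val selected = partitions.filter(p => filters.filterNot(_.isDynamic).forall(_.pred(p)))
        println(s"planning-time selection: $selected") // List(2, 3, 4, 5)

        // Execution phase: the dynamic filters (now evaluable) prune further.
        val finalParts = selected.filter(p => filters.filter(_.isDynamic).forall(_.pred(p)))
        println(s"runtime selection: $finalParts") // List(4)
      }
    }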
@transient protected lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicDataFilters = dataFilters.filter(isDynamicPruningFilter) val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) if (dynamicPartitionFilters.nonEmpty) { @@ -278,7 +285,11 @@ trait FileSourceScanLike extends DataSourceScanExec { val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) }, Nil) - val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + var ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + if (dynamicDataFilters.nonEmpty) { + val filePruningRunner = new FilePruningRunner(dynamicDataFilters) + ret = ret.map(filePruningRunner.prune) + } setFilesNumAndSizeMetric(ret, false) val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 driverMetrics("pruningTime").set(timeTakenMs) @@ -288,14 +299,6 @@ trait FileSourceScanLike extends DataSourceScanExec { } } - /** - * [[partitionFilters]] can contain subqueries whose results are available only at runtime so - * accessing [[selectedPartitions]] should be guarded by this method during planning - */ - private def hasPartitionsAvailableAtRunTime: Boolean = { - partitionFilters.exists(ExecSubqueryExpression.hasSubquery) - } - private def toAttribute(colName: String): Option[Attribute] = output.find(_.name == colName) @@ -339,8 +342,7 @@ trait FileSourceScanLike extends DataSourceScanExec { spec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get) val shouldCalculateSortOrder = conf.getConf(SQLConf.LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING) && - sortColumns.nonEmpty && - !hasPartitionsAvailableAtRunTime + sortColumns.nonEmpty val sortOrder = if (shouldCalculateSortOrder) { // In case of bucketing, its possible to have multiple files belonging to the @@ -371,35 +373,29 @@ trait FileSourceScanLike extends DataSourceScanExec { } } - private def translatePushedDownFilters(dataFilters: Seq[Expression]): Seq[Filter] = { + private def translateToV1Filters( + dataFilters: Seq[Expression], + scalarSubqueryToLiteral: execution.ScalarSubquery => Literal): Seq[Filter] = { + val scalarSubqueryReplaced = dataFilters.map(_.transform { + // Replace scalar subquery to literal so that `DataSourceStrategy.translateFilter` can + // support translating it. + case scalarSubquery: execution.ScalarSubquery => scalarSubqueryToLiteral(scalarSubquery) + }) + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) // `dataFilters` should not include any constant metadata col filters // because the metadata struct has been flatted in FileSourceStrategy // and thus metadata col filters are invalid to be pushed down. Metadata that is generated // during the scan can be used for filters. - dataFilters.filterNot(_.references.exists { + scalarSubqueryReplaced.filterNot(_.references.exists { case FileSourceConstantMetadataAttribute(_) => true case _ => false }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) } + // This field may execute subquery expressions and should not be accessed during planning. 
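A rough model of the parameterized substitution performed by translateToV1Filters above; the tiny Expr ADT is invented for the example and stands in for Catalyst expressions:

    sealed trait Expr
    final case class Lit(value: Any) extends Expr
    final case class ScalarSubq(exprId: Long) extends Expr
    final case class EqualTo(column: String, value: Expr) extends Expr

    object TranslateSketch {
      // Replace subqueries using a caller-supplied strategy, then "translate".
      def translate(filters: Seq[Expr], subqueryToLiteral: ScalarSubq => Lit): Seq[String] =
        filters.map {
          case EqualTo(col, s: ScalarSubq) => s"EqualTo($col,${subqueryToLiteral(s).value})"
          case EqualTo(col, Lit(v))        => s"EqualTo($col,$v)"
          case other                       => other.toString
        }

      def main(args: Array[String]): Unit = {
        val filters = Seq(EqualTo("key", ScalarSubq(42L)), EqualTo("val", Lit(2)))
        // Execution path: the subquery has been materialized, plug in its value.
        println(translate(filters, _ => Lit(7)))
        // Planning/EXPLAIN path: nothing has run yet, plug in a display placeholder.
        println(translate(filters, s => Lit("ScalarSubquery#" + s.exprId)))
      }
    }

The second call mirrors the ScalarSubquery#x placeholders that appear in the updated PushedFilters of the explain outputs later in this diff.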
@transient - protected lazy val pushedDownFilters: Seq[Filter] = translatePushedDownFilters(dataFilters) - - @transient - protected lazy val dynamicallyPushedDownFilters: Seq[Filter] = { - if (dataFilters.exists(_.exists(_.isInstanceOf[execution.ScalarSubquery]))) { - // Replace scalar subquery to literal so that `DataSourceStrategy.translateFilter` can - // support translate it. The subquery must has been materialized since SparkPlan always - // execute subquery first. - val normalized = dataFilters.map(_.transform { - case scalarSubquery: execution.ScalarSubquery => scalarSubquery.toLiteral - }) - translatePushedDownFilters(normalized) - } else { - pushedDownFilters - } - } + protected lazy val pushedDownFilters: Seq[Filter] = translateToV1Filters(dataFilters, _.toLiteral) override lazy val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") @@ -407,13 +403,17 @@ trait FileSourceScanLike extends DataSourceScanExec { val locationDesc = location.getClass.getSimpleName + Utils.buildLocationMetadata(location.rootPaths, maxMetadataValueLength) + // `metadata` is accessed during planning and the scalar subquery is not executed yet. Here + // we get the pretty string of the scalar subquery, for display purpose only. + val pushedFiltersForDisplay = translateToV1Filters( + dataFilters, s => Literal("ScalarSubquery#" + s.exprId.id)) val metadata = Map( "Format" -> relation.fileFormat.toString, "ReadSchema" -> requiredSchema.catalogString, "Batched" -> supportsColumnar.toString, "PartitionFilters" -> seqToString(partitionFilters), - "PushedFilters" -> seqToString(pushedDownFilters), + "PushedFilters" -> seqToString(pushedFiltersForDisplay), "DataFilters" -> seqToString(dataFilters), "Location" -> locationDesc) @@ -561,7 +561,7 @@ case class FileSourceScanExec( dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, requiredSchema = requiredSchema, - filters = dynamicallyPushedDownFilters, + filters = pushedDownFilters, options = options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 5e6e0ad039258..94c2d2ffaca59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale +import scala.collection.immutable.ListMap import scala.collection.mutable import org.apache.hadoop.fs.Path @@ -670,9 +671,10 @@ object DataSourceStrategy // A map from original Catalyst expressions to corresponding translated data source filters. // If a predicate is not in this map, it means it cannot be pushed down. 
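The switch to ListMap in the lines that follow (SPARK-41636) keeps the translated filters in a deterministic order; a generic illustration with plain Scala collections (not Spark code):

    import scala.collection.immutable.ListMap

    object OrderSketch {
      def main(args: Array[String]): Unit = {
        val predicates = Seq("a" -> "IsNotNull(a)", "b" -> "GreaterThan(b,3)", "c" -> "EqualTo(c,1)")

        // ListMap preserves insertion order, so the rendered filter list is
        // stable across otherwise identical plans and can serve as a cache key.
        val ordered = ListMap(predicates: _*)
        println(ordered.values.mkString("[", ", ", "]"))

        // A plain immutable Map makes no ordering guarantee in general (small
        // maps happen to preserve insertion order, larger hash-based maps do
        // not), so the rendered string would not be a stable cache key.
        val unordered = predicates.toMap
        println(unordered.values.mkString("[", ", ", "]"))
      }
    }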
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - val translatedMap: Map[Expression, Filter] = predicates.flatMap { p => + // SPARK-41636: we keep the order of the predicates to avoid CodeGenerator cache misses + val translatedMap: Map[Expression, Filter] = ListMap(predicates.flatMap { p => translateFilter(p, supportNestedPredicatePushdown).map(f => p -> f) - }.toMap + }: _*) val pushedFilters: Seq[Filter] = translatedMap.values.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 1b28294e94a88..2535440add19a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution.datasources +import scala.collection.mutable + import org.apache.hadoop.fs._ +import org.apache.spark.paths.SparkPath import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StructType @@ -43,6 +46,58 @@ case class FileStatusWithMetadata(fileStatus: FileStatus, metadata: Map[String, */ case class PartitionDirectory(values: InternalRow, files: Seq[FileStatusWithMetadata]) +/** + * A runner that extracts file metadata filters from the given `filters` and use it to prune files + * in `PartitionDirectory`. + */ +class FilePruningRunner(filters: Seq[Expression]) { + // retrieve the file constant metadata filters and reduce to a final filter expression that can + // be applied to files. + val fileMetadataFilterOpt = filters.filter { f => + f.references.nonEmpty && f.references.forall { + case FileSourceConstantMetadataAttribute(metadataAttr) => + // we only know block start and length after splitting files, so skip it here + metadataAttr.name != FileFormat.FILE_BLOCK_START && + metadataAttr.name != FileFormat.FILE_BLOCK_LENGTH + case _ => false + } + }.reduceOption(And) + + // - Retrieve all required metadata attributes and put them into a sequence + // - Bind all file constant metadata attribute references to their respective index + val requiredMetadataColumnNames: mutable.Buffer[String] = mutable.Buffer.empty + val boundedFilterMetadataStructOpt = fileMetadataFilterOpt.map { fileMetadataFilter => + Predicate.createInterpreted(fileMetadataFilter.transform { + case attr: AttributeReference => + val existingMetadataColumnIndex = requiredMetadataColumnNames.indexOf(attr.name) + val metadataColumnIndex = if (existingMetadataColumnIndex >= 0) { + existingMetadataColumnIndex + } else { + requiredMetadataColumnNames += attr.name + requiredMetadataColumnNames.length - 1 + } + BoundReference(metadataColumnIndex, attr.dataType, nullable = true) + }) + } + + private def matchFileMetadataPredicate(partitionValues: InternalRow, f: FileStatus): Boolean = { + // use option.forall, so if there is no filter no metadata struct, return true + boundedFilterMetadataStructOpt.forall { boundedFilter => + val row = + FileFormat.createMetadataInternalRow(partitionValues, requiredMetadataColumnNames.toSeq, + SparkPath.fromFileStatus(f), f.getLen, f.getModificationTime) + boundedFilter.eval(row) + } + } + + def prune(pd: PartitionDirectory): PartitionDirectory = { + val prunedFiles = pd.files.filter { f => + matchFileMetadataPredicate(InternalRow.empty, f.fileStatus) + } + pd.copy(files = prunedFiles) + } +} + object 
PartitionDirectory { // For backward compat with code that does not know about extra file metadata def apply(values: InternalRow, files: Array[FileStatus]): PartitionDirectory = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 5673e12927c70..551fe253657c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -174,11 +174,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec val bucketSet = if (shouldPruneBuckets(bucketSpec)) { - // subquery expressions are filtered out because they can't be used to prune buckets - // as data filters, yet they would be executed - val normalizedFiltersWithoutSubqueries = - normalizedFilters.filterNot(SubqueryExpression.hasSubquery) - genBucketSet(normalizedFiltersWithoutSubqueries, bucketSpec.get) + genBucketSet(normalizedFilters, bucketSpec.get) } else { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index b25162aad9a77..ef4fff2360097 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.datasources.FileFormat.createMetadataInternalRow import org.apache.spark.sql.types.StructType /** @@ -73,48 +72,10 @@ abstract class PartitioningAwareFileIndex( isDataPath(f.getPath) && f.getLen > 0 } - // retrieve the file constant metadata filters and reduce to a final filter expression that can - // be applied to files. 
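The block removed below is the metadata-filter logic that now lives in the new FilePruningRunner above. A rough standalone model of that pruning, with an invented FileMeta type and sample files, reduces the per-file constant metadata predicates with logical AND and evaluates them against each file:

    // Toy stand-ins for per-file constant metadata and metadata predicates.
    final case class FileMeta(path: String, length: Long, modificationTime: Long)

    object FilePruneSketch {
      def main(args: Array[String]): Unit = {
        val files = Seq(
          FileMeta("part-0", length = 0L, modificationTime = 100L),
          FileMeta("part-1", length = 1024L, modificationTime = 200L),
          FileMeta("part-2", length = 2048L, modificationTime = 300L))

        // Individual metadata predicates, combined into one conjunction,
        // mirroring filters.reduceOption(And) in the real code.
        val predicates: Seq[FileMeta => Boolean] =
          Seq(_.length > 0L, _.modificationTime >= 150L)
        val combined: FileMeta => Boolean = f => predicates.forall(_(f))

        val pruned = files.filter(combined)
        println(pruned.map(_.path)) // List(part-1, part-2)
      }
    }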
- val fileMetadataFilterOpt = dataFilters.filter { f => - f.references.nonEmpty && f.references.forall { - case FileSourceConstantMetadataAttribute(metadataAttr) => - // we only know block start and length after splitting files, so skip it here - metadataAttr.name != FileFormat.FILE_BLOCK_START && - metadataAttr.name != FileFormat.FILE_BLOCK_LENGTH - case _ => false - } - }.reduceOption(expressions.And) - - // - Retrieve all required metadata attributes and put them into a sequence - // - Bind all file constant metadata attribute references to their respective index - val requiredMetadataColumnNames: mutable.Buffer[String] = mutable.Buffer.empty - val boundedFilterMetadataStructOpt = fileMetadataFilterOpt.map { fileMetadataFilter => - Predicate.createInterpreted(fileMetadataFilter.transform { - case attr: AttributeReference => - val existingMetadataColumnIndex = requiredMetadataColumnNames.indexOf(attr.name) - val metadataColumnIndex = if (existingMetadataColumnIndex >= 0) { - existingMetadataColumnIndex - } else { - requiredMetadataColumnNames += attr.name - requiredMetadataColumnNames.length - 1 - } - BoundReference(metadataColumnIndex, attr.dataType, nullable = true) - }) - } - - def matchFileMetadataPredicate(partitionValues: InternalRow, f: FileStatus): Boolean = { - // use option.forall, so if there is no filter no metadata struct, return true - boundedFilterMetadataStructOpt.forall { boundedFilter => - val row = - createMetadataInternalRow(partitionValues, requiredMetadataColumnNames.toSeq, - SparkPath.fromFileStatus(f), f.getLen, f.getModificationTime) - boundedFilter.eval(row) - } - } - + val filePruningRunner = new FilePruningRunner(dataFilters) val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { - PartitionDirectory(InternalRow.empty, allFiles().toArray - .filter(f => isNonEmptyFile(f) && matchFileMetadataPredicate(InternalRow.empty, f))) :: Nil + filePruningRunner.prune( + PartitionDirectory(InternalRow.empty, allFiles().toArray.filter(isNonEmptyFile))) :: Nil } else { if (recursiveFileLookup) { throw new IllegalArgumentException( @@ -125,14 +86,13 @@ abstract class PartitioningAwareFileIndex( val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => // Directory has children files in it, return them - existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f) && - matchFileMetadataPredicate(values, f)) + existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f)) case None => // Directory does not exist, or has no children files Nil } - PartitionDirectory(values, files.toArray) + filePruningRunner.prune(PartitionDirectory(values, files.toArray)) } } logTrace("Selected files after partition pruning:\n\t" + selectedPartitions.mkString("\n\t")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 4b53819739262..eba3c71f871e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -190,10 +190,17 @@ case class BatchScanExec( Seq.fill(numSplits)(Seq.empty)) } } else { + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. 
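To see why duplicated partition values matter for the grouping performed next (hence p.uniquePartitionValues in this change), here is a toy model; PartValue and InputPartition are invented, whereas Spark wraps rows in InternalRowComparableWrapper for value-based equality:

    // Simplified model of grouping input partitions by partition value when
    // the same value can appear more than once (no SPJ grouping happened earlier).
    final case class PartValue(v: Int)
    final case class InputPartition(value: PartValue, name: String)

    object GroupPartitionsSketch {
      def main(args: Array[String]): Unit = {
        val parts = Seq(
          InputPartition(PartValue(1), "p1a"),
          InputPartition(PartValue(1), "p1b"),
          InputPartition(PartValue(2), "p2"))

        val grouped: Map[PartValue, Seq[InputPartition]] = parts.groupBy(_.value)

        // Mapping over the raw (possibly duplicated) partition values would emit
        // the group for value 1 twice; deduplicating first keeps each group once.
        val uniqueValues = parts.map(_.value).distinct
        val finalPartitions = uniqueValues.map(v => grouped.getOrElse(v, Seq.empty))
        println(finalPartitions.map(_.map(_.name))) // List(List(p1a, p1b), List(p2))
      }
    }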
val partitionMapping = groupedPartitions.map { case (row, parts) => InternalRowComparableWrapper(row, p.expressions) -> parts }.toMap - finalPartitions = p.partitionValues.map { partValue => + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. + finalPartitions = p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala index a7d1edefcd611..6496f9a0006e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala @@ -79,7 +79,7 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec { setDefaultValue: Boolean): Seq[ExprCode] = { ctx.currentVars = null ctx.INPUT_ROW = row - plan.output.zipWithIndex.map { case (a, i) => + plan.output.toIndexedSeq.zipWithIndex.map { case (a, i) => val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) if (setDefaultValue) { // the variables are needed even there is no matched rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 0241f683d6902..8d49b1558d687 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -556,14 +556,18 @@ case class SortMergeJoinExec( val doJoin = joinType match { case _: InnerLike => + val cleanedFlag = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "cleanedFlag", v => s"$v = false;") codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow, - eagerCleanup) + eagerCleanup, cleanedFlag) case LeftOuter | RightOuter => codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, ctx.freshName("hasOutputRow"), outputRow, eagerCleanup) case LeftSemi => + val cleanedFlag = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "cleanedFlag", v => s"$v = false;") codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, - ctx.freshName("hasOutputRow"), outputRow, eagerCleanup) + ctx.freshName("hasOutputRow"), outputRow, eagerCleanup, cleanedFlag) case LeftAnti => codegenAnti(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, loadStreamed, ctx.freshName("hasMatchedRow"), outputRow, eagerCleanup) @@ -606,8 +610,13 @@ case class SortMergeJoinExec( bufferedRow: String, conditionCheck: String, outputRow: String, - eagerCleanup: String): String = { + eagerCleanup: String, + cleanedFlag: String): String = { s""" + |if($cleanedFlag) { + | return; + |} + | |while ($findNextJoinRows) { | $beforeLoop | while ($matchIterator.hasNext()) { @@ -617,6 +626,7 @@ case class SortMergeJoinExec( | } | if (shouldStop()) return; |} + |$cleanedFlag = true; |$eagerCleanup """.stripMargin } @@ -665,8 +675,13 @@ case class SortMergeJoinExec( 
conditionCheck: String, hasOutputRow: String, outputRow: String, - eagerCleanup: String): String = { + eagerCleanup: String, + cleanedFlag: String): String = { s""" + |if($cleanedFlag) { + | return; + |} + | |while ($findNextJoinRows) { | $beforeLoop | boolean $hasOutputRow = false; @@ -679,6 +694,7 @@ case class SortMergeJoinExec( | } | if (shouldStop()) return; |} + |$cleanedFlag = true; |$eagerCleanup """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 05239d8d16462..36cb2e17835a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -175,6 +175,7 @@ object UserDefinedPythonTableFunction { val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE) val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + val workerMemoryMb = SQLConf.get.pythonUDTFAnalyzerMemory val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) @@ -192,6 +193,9 @@ object UserDefinedPythonTableFunction { if (simplifiedTraceback) { envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") } + workerMemoryMb.foreach { memoryMb => + envVars.put("PYSPARK_UDTF_ANALYZER_MEMORY_MB", memoryMb.toString) + } envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala deleted file mode 100644 index e6d1381b2b620..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.TimeUnit - -import scala.concurrent.duration.Duration - -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY -import org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToMillis -import org.apache.spark.sql.catalyst.util.IntervalUtils -import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.unsafe.types.UTF8String - -private object Triggers { - def validate(intervalMs: Long): Unit = { - require(intervalMs >= 0, "the interval of trigger should not be negative") - } - - def convert(interval: String): Long = { - val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval)) - if (cal.months != 0) { - throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") - } - val microsInDays = Math.multiplyExact(cal.days, MICROS_PER_DAY) - microsToMillis(Math.addExact(cal.microseconds, microsInDays)) - } - - def convert(interval: Duration): Long = interval.toMillis - - def convert(interval: Long, unit: TimeUnit): Long = unit.toMillis(interval) -} - -/** - * A [[Trigger]] that processes all available data in one batch then terminates the query. - */ -case object OneTimeTrigger extends Trigger - -/** - * A [[Trigger]] that processes all available data in multiple batches then terminates the query. - */ -case object AvailableNowTrigger extends Trigger - -/** - * A [[Trigger]] that runs a query periodically based on the processing time. If `interval` is 0, - * the query will run as fast as possible. - */ -case class ProcessingTimeTrigger(intervalMs: Long) extends Trigger { - Triggers.validate(intervalMs) -} - -object ProcessingTimeTrigger { - import Triggers._ - - def apply(interval: String): ProcessingTimeTrigger = { - ProcessingTimeTrigger(convert(interval)) - } - - def apply(interval: Duration): ProcessingTimeTrigger = { - ProcessingTimeTrigger(convert(interval)) - } - - def create(interval: String): ProcessingTimeTrigger = { - apply(interval) - } - - def create(interval: Long, unit: TimeUnit): ProcessingTimeTrigger = { - ProcessingTimeTrigger(convert(interval, unit)) - } -} - -/** - * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. 
- */ -case class ContinuousTrigger(intervalMs: Long) extends Trigger { - Triggers.validate(intervalMs) -} - -object ContinuousTrigger { - import Triggers._ - - def apply(interval: String): ContinuousTrigger = { - ContinuousTrigger(convert(interval)) - } - - def apply(interval: Duration): ContinuousTrigger = { - ContinuousTrigger(convert(interval)) - } - - def create(interval: String): ContinuousTrigger = { - apply(interval) - } - - def create(interval: Long, unit: TimeUnit): ContinuousTrigger = { - ContinuousTrigger(convert(interval, unit)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index d4366fe732be4..a2868df941178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -611,8 +611,11 @@ class RocksDB( if (log.isWarnEnabled) dbLogLevel = InfoLogLevel.WARN_LEVEL if (log.isInfoEnabled) dbLogLevel = InfoLogLevel.INFO_LEVEL if (log.isDebugEnabled) dbLogLevel = InfoLogLevel.DEBUG_LEVEL - dbOptions.setLogger(dbLogger) + dbLogger.setInfoLogLevel(dbLogLevel) + // The log level set on dbLogger is the effective one; the level set on dbOptions isn't applied + // to a customized logger. We still set it as it might show up in RocksDB config or logging. dbOptions.setInfoLogLevel(dbLogLevel) + dbOptions.setLogger(dbLogger) logInfo(s"Set RocksDB native logging level to $dbLogLevel") dbLogger } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 48fd009d6e70f..4f7cf8da78722 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -42,6 +42,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; +import static org.apache.spark.sql.RowFactory.create; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.test.TestSparkSession; @@ -1956,6 +1957,24 @@ public void testSpecificLists() { Assert.assertEquals(beans, dataset.collectAsList()); } + @Test + public void testRowEncoder() { + final StructType schema = new StructType() + .add("a", "int") + .add("b", "string"); + final Dataset<Row> df = spark.range(3) + .map(new MapFunction<Long, Row>() { + @Override + public Row call(Long i) { + return create(i.intValue(), "s" + i); + } + }, + Encoders.row(schema)) + .filter(col("a").geq(1)); + final List<Row> expected = Arrays.asList(create(1, "s1"), create(2, "s2")); + Assert.assertEquals(expected, df.collectAsList()); + } + public static class SpecificListsBean implements Serializable { private ArrayList arrayList; private LinkedList linkedList; diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out index 650b61b419245..11e2651c6f225 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out @@ -229,7 +229,8 @@ org.apache.spark.sql.AnalysisException "errorClass" : "UNEXPECTED_POSITIONAL_ARGUMENT", "sqlState" : "4274K",
"messageParameters" : { - "functionName" : "`mask`" + "functionName" : "`mask`", + "parameterName" : "`lowerChar`" }, "queryContext" : [ { "objectType" : "", @@ -292,6 +293,7 @@ org.apache.spark.sql.AnalysisException "sqlState" : "4274K", "messageParameters" : { "functionName" : "`mask`", + "index" : "0", "parameterName" : "`str`" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 44b2679f89d86..7dfaaea46b75d 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -398,7 +398,7 @@ AdaptiveSparkPlan (3) Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp1] -PushedFilters: [IsNotNull(key), IsNotNull(val), GreaterThan(val,3)] +PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(key,ScalarSubquery#x), GreaterThan(val,3)] ReadSchema: struct (2) Filter @@ -425,7 +425,7 @@ AdaptiveSparkPlan (10) Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp2] -PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(val,2)] +PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(key,ScalarSubquery#x), EqualTo(val,2)] ReadSchema: struct (5) Filter diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 0cd94abc9b307..ef4d57735aa39 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -405,7 +405,7 @@ struct Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp1] -PushedFilters: [IsNotNull(key), IsNotNull(val), GreaterThan(val,3)] +PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(key,ScalarSubquery#x), GreaterThan(val,3)] ReadSchema: struct (2) ColumnarToRow [codegen id : 1] @@ -433,7 +433,7 @@ Subquery:2 Hosting operator id = 1 Hosting Expression = Subquery scalar-subquery Output [2]: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp2] -PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(val,2)] +PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(key,ScalarSubquery#x), EqualTo(val,2)] ReadSchema: struct (5) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out index 77c15b56c8dab..60301862a35c9 100644 --- a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out @@ -214,7 +214,8 @@ org.apache.spark.sql.AnalysisException "errorClass" : "UNEXPECTED_POSITIONAL_ARGUMENT", "sqlState" : "4274K", "messageParameters" : { - "functionName" : "`mask`" + "functionName" : "`mask`", + "parameterName" : "`lowerChar`" }, "queryContext" : [ { "objectType" : "", @@ -283,6 +284,7 @@ org.apache.spark.sql.AnalysisException "sqlState" : "4274K", "messageParameters" : { "functionName" : "`mask`", + "index" : "0", "parameterName" : "`str`" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt index 16bdfb1041619..0986e92088caa 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b.sf100/explain.txt @@ -648,7 +648,7 @@ BroadcastExchange (114) Output [2]: [d_date_sk#36, d_week_seq#100] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#101), IsNotNull(d_date_sk)] ReadSchema: struct (111) ColumnarToRow [codegen id : 1] @@ -741,7 +741,7 @@ BroadcastExchange (128) Output [2]: [d_date_sk#60, d_week_seq#108] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#109), IsNotNull(d_date_sk)] ReadSchema: struct (125) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt index cc8b88f3adcbf..3f4f3653371d9 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14b/explain.txt @@ -618,7 +618,7 @@ BroadcastExchange (108) Output [2]: [d_date_sk#40, d_week_seq#100] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#101), IsNotNull(d_date_sk)] ReadSchema: struct (105) ColumnarToRow [codegen id : 1] @@ -711,7 +711,7 @@ BroadcastExchange (122) Output [2]: [d_date_sk#64, d_week_seq#108] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#109), IsNotNull(d_date_sk)] ReadSchema: struct (119) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt index 19643cccab639..572452c72529e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54.sf100/explain.txt @@ -387,7 +387,7 @@ BroadcastExchange (69) Output [2]: [d_date_sk#29, d_month_seq#41] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,ScalarSubquery#42), LessThanOrEqual(d_month_seq,ScalarSubquery#43), IsNotNull(d_date_sk)] ReadSchema: struct (66) ColumnarToRow [codegen id : 1] @@ -395,7 +395,7 @@ Input [2]: [d_date_sk#29, d_month_seq#41] (67) Filter [codegen id : 1] Input [2]: [d_date_sk#29, d_month_seq#41] -Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= ReusedSubquery Subquery scalar-subquery#42, [id=#43])) AND (d_month_seq#41 <= ReusedSubquery Subquery scalar-subquery#44, [id=#45])) AND isnotnull(d_date_sk#29)) +Condition : 
(((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= ReusedSubquery Subquery scalar-subquery#42, [id=#44])) AND (d_month_seq#41 <= ReusedSubquery Subquery scalar-subquery#43, [id=#45])) AND isnotnull(d_date_sk#29)) (68) Project [codegen id : 1] Output [1]: [d_date_sk#29] @@ -405,11 +405,11 @@ Input [2]: [d_date_sk#29, d_month_seq#41] Input [1]: [d_date_sk#29] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [plan_id=9] -Subquery:4 Hosting operator id = 67 Hosting Expression = ReusedSubquery Subquery scalar-subquery#42, [id=#43] +Subquery:4 Hosting operator id = 67 Hosting Expression = ReusedSubquery Subquery scalar-subquery#42, [id=#44] -Subquery:5 Hosting operator id = 67 Hosting Expression = ReusedSubquery Subquery scalar-subquery#44, [id=#45] +Subquery:5 Hosting operator id = 67 Hosting Expression = ReusedSubquery Subquery scalar-subquery#43, [id=#45] -Subquery:6 Hosting operator id = 65 Hosting Expression = Subquery scalar-subquery#42, [id=#43] +Subquery:6 Hosting operator id = 65 Hosting Expression = Subquery scalar-subquery#42, [id=#44] * HashAggregate (76) +- Exchange (75) +- * HashAggregate (74) @@ -455,7 +455,7 @@ Functions: [] Aggregate Attributes: [] Results [1]: [(d_month_seq + 1)#49] -Subquery:7 Hosting operator id = 65 Hosting Expression = Subquery scalar-subquery#44, [id=#45] +Subquery:7 Hosting operator id = 65 Hosting Expression = Subquery scalar-subquery#43, [id=#45] * HashAggregate (83) +- Exchange (82) +- * HashAggregate (81) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt index cefaff0c09d39..502d4f3ee6ab3 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q54/explain.txt @@ -372,7 +372,7 @@ BroadcastExchange (66) Output [2]: [d_date_sk#29, d_month_seq#41] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,ScalarSubquery#42), LessThanOrEqual(d_month_seq,ScalarSubquery#43), IsNotNull(d_date_sk)] ReadSchema: struct (63) ColumnarToRow [codegen id : 1] @@ -380,7 +380,7 @@ Input [2]: [d_date_sk#29, d_month_seq#41] (64) Filter [codegen id : 1] Input [2]: [d_date_sk#29, d_month_seq#41] -Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= ReusedSubquery Subquery scalar-subquery#42, [id=#43])) AND (d_month_seq#41 <= ReusedSubquery Subquery scalar-subquery#44, [id=#45])) AND isnotnull(d_date_sk#29)) +Condition : (((isnotnull(d_month_seq#41) AND (d_month_seq#41 >= ReusedSubquery Subquery scalar-subquery#42, [id=#44])) AND (d_month_seq#41 <= ReusedSubquery Subquery scalar-subquery#43, [id=#45])) AND isnotnull(d_date_sk#29)) (65) Project [codegen id : 1] Output [1]: [d_date_sk#29] @@ -390,11 +390,11 @@ Input [2]: [d_date_sk#29, d_month_seq#41] Input [1]: [d_date_sk#29] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [plan_id=10] -Subquery:4 Hosting operator id = 64 Hosting Expression = ReusedSubquery Subquery scalar-subquery#42, [id=#43] +Subquery:4 Hosting operator id = 64 Hosting Expression = ReusedSubquery Subquery scalar-subquery#42, [id=#44] -Subquery:5 Hosting operator id = 64 Hosting Expression = ReusedSubquery Subquery scalar-subquery#44, [id=#45] +Subquery:5 Hosting 
operator id = 64 Hosting Expression = ReusedSubquery Subquery scalar-subquery#43, [id=#45] -Subquery:6 Hosting operator id = 62 Hosting Expression = Subquery scalar-subquery#42, [id=#43] +Subquery:6 Hosting operator id = 62 Hosting Expression = Subquery scalar-subquery#42, [id=#44] * HashAggregate (73) +- Exchange (72) +- * HashAggregate (71) @@ -440,7 +440,7 @@ Functions: [] Aggregate Attributes: [] Results [1]: [(d_month_seq + 1)#49] -Subquery:7 Hosting operator id = 62 Hosting Expression = Subquery scalar-subquery#44, [id=#45] +Subquery:7 Hosting operator id = 62 Hosting Expression = Subquery scalar-subquery#43, [id=#45] * HashAggregate (80) +- Exchange (79) +- * HashAggregate (78) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt index 26ffe2e0b323e..d9083741a88e7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58.sf100/explain.txt @@ -320,7 +320,7 @@ Condition : isnotnull(d_date_sk#5) Output [2]: [d_date#40, d_week_seq#41] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#42)] ReadSchema: struct (54) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt index cdb5e45f66872..7f95e52cb8df5 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q58/explain.txt @@ -320,7 +320,7 @@ Condition : isnotnull(d_date_sk#7) Output [2]: [d_date#40, d_week_seq#41] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#42)] ReadSchema: struct (54) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6.sf100/explain.txt index 93db1e57839df..ac69497fb26ca 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6.sf100/explain.txt @@ -272,7 +272,7 @@ BroadcastExchange (50) Output [2]: [d_date_sk#16, d_month_seq#26] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), EqualTo(d_month_seq,ScalarSubquery#27), IsNotNull(d_date_sk)] ReadSchema: struct (47) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6/explain.txt index bd5bdfb666100..75644fea091fe 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q6/explain.txt @@ -242,7 +242,7 @@ BroadcastExchange (44) Output [2]: [d_date_sk#9, d_month_seq#26] Batched: true Location 
[not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), EqualTo(d_month_seq,ScalarSubquery#27), IsNotNull(d_date_sk)] ReadSchema: struct (41) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt index 620bab62bf16d..69023c88202af 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt @@ -760,15 +760,15 @@ Input [4]: [cs_item_sk#122, sum#123, sum#124, isEmpty#125] Keys [1]: [cs_item_sk#122] Functions [2]: [sum(UnscaledValue(cs_ext_list_price#126)), sum(((cr_refunded_cash#127 + cr_reversed_charge#128) + cr_store_credit#129))] Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_list_price#126))#33, sum(((cr_refunded_cash#127 + cr_reversed_charge#128) + cr_store_credit#129))#34] -Results [3]: [cs_item_sk#122, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#126))#33,17,2) AS sale#35, sum(((cr_refunded_cash#127 + cr_reversed_charge#128) + cr_store_credit#129))#34 AS refund#36] +Results [3]: [cs_item_sk#122, MakeDecimal(sum(UnscaledValue(cs_ext_list_price#126))#33,17,2) AS sale#130, sum(((cr_refunded_cash#127 + cr_reversed_charge#128) + cr_store_credit#129))#34 AS refund#131] (126) Filter [codegen id : 35] -Input [3]: [cs_item_sk#122, sale#35, refund#36] -Condition : ((isnotnull(sale#35) AND isnotnull(refund#36)) AND (cast(sale#35 as decimal(21,2)) > (2 * refund#36))) +Input [3]: [cs_item_sk#122, sale#130, refund#131] +Condition : ((isnotnull(sale#130) AND isnotnull(refund#131)) AND (cast(sale#130 as decimal(21,2)) > (2 * refund#131))) (127) Project [codegen id : 35] Output [1]: [cs_item_sk#122] -Input [3]: [cs_item_sk#122, sale#35, refund#36] +Input [3]: [cs_item_sk#122, sale#130, refund#131] (128) Sort [codegen id : 35] Input [1]: [cs_item_sk#122] @@ -785,239 +785,239 @@ Output [11]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#1 Input [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117, cs_item_sk#122] (131) ReusedExchange [Reuses operator id: 191] -Output [2]: [d_date_sk#130, d_year#131] +Output [2]: [d_date_sk#132, d_year#133] (132) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_sold_date_sk#117] -Right keys [1]: [d_date_sk#130] +Right keys [1]: [d_date_sk#132] Join type: Inner Join condition: None (133) Project [codegen id : 51] -Output [11]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131] -Input [13]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117, d_date_sk#130, d_year#131] +Output [11]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133] +Input [13]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, 
ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117, d_date_sk#132, d_year#133] (134) ReusedExchange [Reuses operator id: 41] -Output [3]: [s_store_sk#132, s_store_name#133, s_zip#134] +Output [3]: [s_store_sk#134, s_store_name#135, s_zip#136] (135) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_store_sk#111] -Right keys [1]: [s_store_sk#132] +Right keys [1]: [s_store_sk#134] Join type: Inner Join condition: None (136) Project [codegen id : 51] -Output [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134] -Input [14]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_sk#132, s_store_name#133, s_zip#134] +Output [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136] +Input [14]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_sk#134, s_store_name#135, s_zip#136] (137) ReusedExchange [Reuses operator id: 47] -Output [6]: [c_customer_sk#135, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, c_first_sales_date_sk#140] +Output [6]: [c_customer_sk#137, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, c_first_sales_date_sk#142] (138) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_customer_sk#107] -Right keys [1]: [c_customer_sk#135] +Right keys [1]: [c_customer_sk#137] Join type: Inner Join condition: None (139) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, c_first_sales_date_sk#140] -Input [18]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_customer_sk#135, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, c_first_sales_date_sk#140] +Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, c_first_sales_date_sk#142] +Input [18]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_customer_sk#137, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, c_first_sales_date_sk#142] (140) ReusedExchange [Reuses operator id: 53] -Output [2]: [d_date_sk#141, d_year#142] +Output [2]: [d_date_sk#143, d_year#144] (141) BroadcastHashJoin 
[codegen id : 51] -Left keys [1]: [c_first_sales_date_sk#140] -Right keys [1]: [d_date_sk#141] +Left keys [1]: [c_first_sales_date_sk#142] +Right keys [1]: [d_date_sk#143] Join type: Inner Join condition: None (142) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, d_year#142] -Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, c_first_sales_date_sk#140, d_date_sk#141, d_year#142] +Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, d_year#144] +Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, c_first_sales_date_sk#142, d_date_sk#143, d_year#144] (143) ReusedExchange [Reuses operator id: 53] -Output [2]: [d_date_sk#143, d_year#144] +Output [2]: [d_date_sk#145, d_year#146] (144) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_first_shipto_date_sk#139] -Right keys [1]: [d_date_sk#143] +Left keys [1]: [c_first_shipto_date_sk#141] +Right keys [1]: [d_date_sk#145] Join type: Inner Join condition: None (145) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144] -Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, c_first_shipto_date_sk#139, d_year#142, d_date_sk#143, d_year#144] +Output [16]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146] +Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, c_first_shipto_date_sk#141, d_year#144, d_date_sk#145, d_year#146] (146) ReusedExchange [Reuses operator id: 62] -Output [2]: [cd_demo_sk#145, cd_marital_status#146] +Output [2]: [cd_demo_sk#147, cd_marital_status#148] (147) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_cdemo_sk#108] -Right keys [1]: [cd_demo_sk#145] +Right keys [1]: [cd_demo_sk#147] Join type: Inner Join 
condition: None (148) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, cd_marital_status#146] -Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, cd_demo_sk#145, cd_marital_status#146] +Output [16]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, cd_marital_status#148] +Input [18]: [ss_item_sk#106, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, cd_demo_sk#147, cd_marital_status#148] (149) ReusedExchange [Reuses operator id: 62] -Output [2]: [cd_demo_sk#147, cd_marital_status#148] +Output [2]: [cd_demo_sk#149, cd_marital_status#150] (150) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_current_cdemo_sk#136] -Right keys [1]: [cd_demo_sk#147] +Left keys [1]: [c_current_cdemo_sk#138] +Right keys [1]: [cd_demo_sk#149] Join type: Inner -Join condition: NOT (cd_marital_status#146 = cd_marital_status#148) +Join condition: NOT (cd_marital_status#148 = cd_marital_status#150) (151) Project [codegen id : 51] -Output [14]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144] -Input [18]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_cdemo_sk#136, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, cd_marital_status#146, cd_demo_sk#147, cd_marital_status#148] +Output [14]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146] +Input [18]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_cdemo_sk#138, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, cd_marital_status#148, cd_demo_sk#149, cd_marital_status#150] (152) ReusedExchange [Reuses operator id: 71] -Output [1]: [p_promo_sk#149] +Output [1]: [p_promo_sk#151] (153) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_promo_sk#112] -Right keys [1]: [p_promo_sk#149] +Right keys [1]: [p_promo_sk#151] Join type: Inner Join condition: None (154) Project [codegen id : 51] -Output [13]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, 
c_current_addr_sk#138, d_year#142, d_year#144] -Input [15]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, p_promo_sk#149] +Output [13]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146] +Input [15]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_promo_sk#112, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, p_promo_sk#151] (155) ReusedExchange [Reuses operator id: 77] -Output [2]: [hd_demo_sk#150, hd_income_band_sk#151] +Output [2]: [hd_demo_sk#152, hd_income_band_sk#153] (156) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_hdemo_sk#109] -Right keys [1]: [hd_demo_sk#150] +Right keys [1]: [hd_demo_sk#152] Join type: Inner Join condition: None (157) Project [codegen id : 51] -Output [13]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151] -Input [15]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, hd_demo_sk#150, hd_income_band_sk#151] +Output [13]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153] +Input [15]: [ss_item_sk#106, ss_hdemo_sk#109, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, hd_demo_sk#152, hd_income_band_sk#153] (158) ReusedExchange [Reuses operator id: 77] -Output [2]: [hd_demo_sk#152, hd_income_band_sk#153] +Output [2]: [hd_demo_sk#154, hd_income_band_sk#155] (159) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_current_hdemo_sk#137] -Right keys [1]: [hd_demo_sk#152] +Left keys [1]: [c_current_hdemo_sk#139] +Right keys [1]: [hd_demo_sk#154] Join type: Inner Join condition: None (160) Project [codegen id : 51] -Output [13]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153] -Input [15]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_hdemo_sk#137, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_demo_sk#152, hd_income_band_sk#153] +Output [13]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155] +Input [15]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, 
c_current_hdemo_sk#139, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_demo_sk#154, hd_income_band_sk#155] (161) ReusedExchange [Reuses operator id: 86] -Output [5]: [ca_address_sk#154, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158] +Output [5]: [ca_address_sk#156, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160] (162) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_addr_sk#110] -Right keys [1]: [ca_address_sk#154] +Right keys [1]: [ca_address_sk#156] Join type: Inner Join condition: None (163) Project [codegen id : 51] -Output [16]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158] -Input [18]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_address_sk#154, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158] +Output [16]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160] +Input [18]: [ss_item_sk#106, ss_addr_sk#110, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_address_sk#156, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160] (164) ReusedExchange [Reuses operator id: 86] -Output [5]: [ca_address_sk#159, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] +Output [5]: [ca_address_sk#161, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] (165) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [c_current_addr_sk#138] -Right keys [1]: [ca_address_sk#159] +Left keys [1]: [c_current_addr_sk#140] +Right keys [1]: [ca_address_sk#161] Join type: Inner Join condition: None (166) Project [codegen id : 51] -Output [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] -Input [21]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, c_current_addr_sk#138, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_address_sk#159, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] +Output [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] +Input [21]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, 
c_current_addr_sk#140, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_address_sk#161, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] (167) ReusedExchange [Reuses operator id: 95] -Output [1]: [ib_income_band_sk#164] +Output [1]: [ib_income_band_sk#166] (168) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [hd_income_band_sk#151] -Right keys [1]: [ib_income_band_sk#164] +Left keys [1]: [hd_income_band_sk#153] +Right keys [1]: [ib_income_band_sk#166] Join type: Inner Join condition: None (169) Project [codegen id : 51] -Output [18]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] -Input [20]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, hd_income_band_sk#151, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, ib_income_band_sk#164] +Output [18]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] +Input [20]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, hd_income_band_sk#153, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, ib_income_band_sk#166] (170) ReusedExchange [Reuses operator id: 95] -Output [1]: [ib_income_band_sk#165] +Output [1]: [ib_income_band_sk#167] (171) BroadcastHashJoin [codegen id : 51] -Left keys [1]: [hd_income_band_sk#153] -Right keys [1]: [ib_income_band_sk#165] +Left keys [1]: [hd_income_band_sk#155] +Right keys [1]: [ib_income_band_sk#167] Join type: Inner Join condition: None (172) Project [codegen id : 51] -Output [17]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163] -Input [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, hd_income_band_sk#153, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, ib_income_band_sk#165] +Output [17]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165] +Input [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, hd_income_band_sk#155, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, 
ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, ib_income_band_sk#167] (173) ReusedExchange [Reuses operator id: 105] -Output [2]: [i_item_sk#166, i_product_name#167] +Output [2]: [i_item_sk#168, i_product_name#169] (174) BroadcastHashJoin [codegen id : 51] Left keys [1]: [ss_item_sk#106] -Right keys [1]: [i_item_sk#166] +Right keys [1]: [i_item_sk#168] Join type: Inner Join condition: None (175) Project [codegen id : 51] -Output [18]: [ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, d_year#142, d_year#144, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, i_item_sk#166, i_product_name#167] -Input [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, s_store_name#133, s_zip#134, d_year#142, d_year#144, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, i_item_sk#166, i_product_name#167] +Output [18]: [ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, d_year#144, d_year#146, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, i_item_sk#168, i_product_name#169] +Input [19]: [ss_item_sk#106, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, s_store_name#135, s_zip#136, d_year#144, d_year#146, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, i_item_sk#168, i_product_name#169] (176) HashAggregate [codegen id : 51] -Input [18]: [ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#131, d_year#142, d_year#144, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, i_item_sk#166, i_product_name#167] -Keys [15]: [i_product_name#167, i_item_sk#166, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, d_year#131, d_year#142, d_year#144] +Input [18]: [ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, d_year#133, d_year#144, d_year#146, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, i_item_sk#168, i_product_name#169] +Keys [15]: [i_product_name#169, i_item_sk#168, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, d_year#133, d_year#144, d_year#146] Functions [4]: [partial_count(1), partial_sum(UnscaledValue(ss_wholesale_cost#114)), partial_sum(UnscaledValue(ss_list_price#115)), partial_sum(UnscaledValue(ss_coupon_amt#116))] -Aggregate Attributes [4]: [count#77, sum#168, sum#169, sum#170] -Results [19]: [i_product_name#167, i_item_sk#166, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, d_year#131, d_year#142, d_year#144, count#81, sum#171, sum#172, sum#173] +Aggregate Attributes [4]: [count#77, sum#170, sum#171, sum#172] +Results [19]: [i_product_name#169, i_item_sk#168, s_store_name#135, s_zip#136, 
ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, d_year#133, d_year#144, d_year#146, count#81, sum#173, sum#174, sum#175] (177) HashAggregate [codegen id : 51] -Input [19]: [i_product_name#167, i_item_sk#166, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, d_year#131, d_year#142, d_year#144, count#81, sum#171, sum#172, sum#173] -Keys [15]: [i_product_name#167, i_item_sk#166, s_store_name#133, s_zip#134, ca_street_number#155, ca_street_name#156, ca_city#157, ca_zip#158, ca_street_number#160, ca_street_name#161, ca_city#162, ca_zip#163, d_year#131, d_year#142, d_year#144] +Input [19]: [i_product_name#169, i_item_sk#168, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, d_year#133, d_year#144, d_year#146, count#81, sum#173, sum#174, sum#175] +Keys [15]: [i_product_name#169, i_item_sk#168, s_store_name#135, s_zip#136, ca_street_number#157, ca_street_name#158, ca_city#159, ca_zip#160, ca_street_number#162, ca_street_name#163, ca_city#164, ca_zip#165, d_year#133, d_year#144, d_year#146] Functions [4]: [count(1), sum(UnscaledValue(ss_wholesale_cost#114)), sum(UnscaledValue(ss_list_price#115)), sum(UnscaledValue(ss_coupon_amt#116))] Aggregate Attributes [4]: [count(1)#85, sum(UnscaledValue(ss_wholesale_cost#114))#86, sum(UnscaledValue(ss_list_price#115))#87, sum(UnscaledValue(ss_coupon_amt#116))#88] -Results [8]: [i_item_sk#166 AS item_sk#174, s_store_name#133 AS store_name#175, s_zip#134 AS store_zip#176, d_year#131 AS syear#177, count(1)#85 AS cnt#178, MakeDecimal(sum(UnscaledValue(ss_wholesale_cost#114))#86,17,2) AS s1#179, MakeDecimal(sum(UnscaledValue(ss_list_price#115))#87,17,2) AS s2#180, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#116))#88,17,2) AS s3#181] +Results [8]: [i_item_sk#168 AS item_sk#176, s_store_name#135 AS store_name#177, s_zip#136 AS store_zip#178, d_year#133 AS syear#179, count(1)#85 AS cnt#180, MakeDecimal(sum(UnscaledValue(ss_wholesale_cost#114))#86,17,2) AS s1#181, MakeDecimal(sum(UnscaledValue(ss_list_price#115))#87,17,2) AS s2#182, MakeDecimal(sum(UnscaledValue(ss_coupon_amt#116))#88,17,2) AS s3#183] (178) Exchange -Input [8]: [item_sk#174, store_name#175, store_zip#176, syear#177, cnt#178, s1#179, s2#180, s3#181] -Arguments: hashpartitioning(item_sk#174, store_name#175, store_zip#176, 5), ENSURE_REQUIREMENTS, [plan_id=18] +Input [8]: [item_sk#176, store_name#177, store_zip#178, syear#179, cnt#180, s1#181, s2#182, s3#183] +Arguments: hashpartitioning(item_sk#176, store_name#177, store_zip#178, 5), ENSURE_REQUIREMENTS, [plan_id=18] (179) Sort [codegen id : 52] -Input [8]: [item_sk#174, store_name#175, store_zip#176, syear#177, cnt#178, s1#179, s2#180, s3#181] -Arguments: [item_sk#174 ASC NULLS FIRST, store_name#175 ASC NULLS FIRST, store_zip#176 ASC NULLS FIRST], false, 0 +Input [8]: [item_sk#176, store_name#177, store_zip#178, syear#179, cnt#180, s1#181, s2#182, s3#183] +Arguments: [item_sk#176 ASC NULLS FIRST, store_name#177 ASC NULLS FIRST, store_zip#178 ASC NULLS FIRST], false, 0 (180) SortMergeJoin [codegen id : 53] Left keys [3]: [item_sk#90, store_name#91, store_zip#92] -Right keys [3]: [item_sk#174, store_name#175, store_zip#176] +Right keys [3]: [item_sk#176, store_name#177, store_zip#178] Join type: Inner -Join condition: (cnt#178 <= cnt#102) +Join 
condition: (cnt#180 <= cnt#102) (181) Project [codegen id : 53] -Output [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#179, s2#180, s3#181, syear#177, cnt#178] -Input [25]: [product_name#89, item_sk#90, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, item_sk#174, store_name#175, store_zip#176, syear#177, cnt#178, s1#179, s2#180, s3#181] +Output [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#181, s2#182, s3#183, syear#179, cnt#180] +Input [25]: [product_name#89, item_sk#90, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, item_sk#176, store_name#177, store_zip#178, syear#179, cnt#180, s1#181, s2#182, s3#183] (182) Exchange -Input [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#179, s2#180, s3#181, syear#177, cnt#178] -Arguments: rangepartitioning(product_name#89 ASC NULLS FIRST, store_name#91 ASC NULLS FIRST, cnt#178 ASC NULLS FIRST, 5), ENSURE_REQUIREMENTS, [plan_id=19] +Input [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#181, s2#182, s3#183, syear#179, cnt#180] +Arguments: rangepartitioning(product_name#89 ASC NULLS FIRST, store_name#91 ASC NULLS FIRST, cnt#180 ASC NULLS FIRST, 5), ENSURE_REQUIREMENTS, [plan_id=19] (183) Sort [codegen id : 54] -Input [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#179, s2#180, s3#181, syear#177, cnt#178] -Arguments: [product_name#89 ASC NULLS FIRST, store_name#91 ASC NULLS FIRST, cnt#178 ASC NULLS FIRST], true, 0 +Input [21]: [product_name#89, store_name#91, store_zip#92, b_street_number#93, b_streen_name#94, b_city#95, b_zip#96, c_street_number#97, c_street_name#98, c_city#99, c_zip#100, syear#101, cnt#102, s1#103, s2#104, s3#105, s1#181, s2#182, s3#183, syear#179, cnt#180] +Arguments: [product_name#89 ASC NULLS FIRST, store_name#91 ASC NULLS FIRST, cnt#180 ASC NULLS FIRST], true, 0 ===== Subqueries ===== @@ -1054,21 +1054,21 @@ BroadcastExchange (191) (188) Scan parquet spark_catalog.default.date_dim -Output [2]: [d_date_sk#130, d_year#131] +Output [2]: [d_date_sk#132, d_year#133] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), EqualTo(d_year,2000), IsNotNull(d_date_sk)] ReadSchema: struct (189) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#130, d_year#131] +Input [2]: [d_date_sk#132, d_year#133] (190) Filter [codegen id : 1] -Input [2]: [d_date_sk#130, d_year#131] -Condition : ((isnotnull(d_year#131) AND (d_year#131 = 2000)) AND isnotnull(d_date_sk#130)) 
+Input [2]: [d_date_sk#132, d_year#133] +Condition : ((isnotnull(d_year#133) AND (d_year#133 = 2000)) AND isnotnull(d_date_sk#132)) (191) BroadcastExchange -Input [2]: [d_date_sk#130, d_year#131] +Input [2]: [d_date_sk#132, d_year#133] Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [plan_id=21] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt index 1440326b862e9..fafd7fd75cbd7 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14.sf100/explain.txt @@ -648,7 +648,7 @@ BroadcastExchange (114) Output [2]: [d_date_sk#36, d_week_seq#100] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#101), IsNotNull(d_date_sk)] ReadSchema: struct (111) ColumnarToRow [codegen id : 1] @@ -741,7 +741,7 @@ BroadcastExchange (128) Output [2]: [d_date_sk#60, d_week_seq#108] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#109), IsNotNull(d_date_sk)] ReadSchema: struct (125) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt index 1e4ca929b9690..4d69899b3b17a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q14/explain.txt @@ -618,7 +618,7 @@ BroadcastExchange (108) Output [2]: [d_date_sk#40, d_week_seq#100] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#101), IsNotNull(d_date_sk)] ReadSchema: struct (105) ColumnarToRow [codegen id : 1] @@ -711,7 +711,7 @@ BroadcastExchange (122) Output [2]: [d_date_sk#64, d_week_seq#108] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_week_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_week_seq), EqualTo(d_week_seq,ScalarSubquery#109), IsNotNull(d_date_sk)] ReadSchema: struct (119) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt index 55bed0dade77f..afdfc51a17dd4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6.sf100/explain.txt @@ -272,7 +272,7 @@ BroadcastExchange (50) Output [2]: [d_date_sk#16, d_month_seq#26] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), EqualTo(d_month_seq,ScalarSubquery#27), IsNotNull(d_date_sk)] ReadSchema: struct (47) ColumnarToRow [codegen id : 1] diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt index 6713acc975445..a2638dac56456 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q6/explain.txt @@ -242,7 +242,7 @@ BroadcastExchange (44) Output [2]: [d_date_sk#9, d_month_seq#26] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] -PushedFilters: [IsNotNull(d_month_seq), IsNotNull(d_date_sk)] +PushedFilters: [IsNotNull(d_month_seq), EqualTo(d_month_seq,ScalarSubquery#27), IsNotNull(d_date_sk)] ReadSchema: struct (41) ColumnarToRow [codegen id : 1] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 6033b9fee848e..a657c6212aa07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -273,4 +273,25 @@ class DatasetCacheSuite extends QueryTest } } } + + test("SPARK-44653: non-trivial DataFrame unions should not break caching") { + val df1 = Seq(1 -> 1).toDF("i", "j") + val df2 = Seq(2 -> 2).toDF("i", "j") + val df3 = Seq(3 -> 3).toDF("i", "j") + + withClue("positive") { + val unionDf = df1.union(df2).select($"i") + unionDf.cache() + val finalDf = unionDf.union(df3.select($"i")) + assert(finalDf.queryExecution.executedPlan.exists(_.isInstanceOf[InMemoryTableScanExec])) + } + + withClue("negative") { + val unionDf = df1.union(df2) + unionDf.cache() + val finalDf = unionDf.union(df3) + // It's by design to break caching here. + assert(!finalDf.queryExecution.executedPlan.exists(_.isInstanceOf[InMemoryTableScanExec])) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 7f358723eeb8f..14f1fb27906a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1709,4 +1709,24 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan checkAnswer(sql(query), expected) } } + + test("SPARK-44132: FULL OUTER JOIN by streamed column name fails with NPE") { + val dsA = Seq((1, "a")).toDF("id", "c1") + val dsB = Seq((2, "b")).toDF("id", "c2") + val dsC = Seq((3, "c")).toDF("id", "c3") + val joined = dsA.join(dsB, Stream("id"), "full_outer").join(dsC, Stream("id"), "full_outer") + + val expected = Seq(Row(1, "a", null, null), Row(2, null, "b", null), Row(3, null, null, "c")) + + checkAnswer(joined, expected) + } + + test("SPARK-44132: FULL OUTER JOIN by streamed column name fails with invalid access") { + val ds = Seq((1, "a")).toDF("id", "c1") + val joined = ds.join(ds, Stream("id"), "full_outer").join(ds, Stream("id"), "full_outer") + + val expected = Seq(Row(1, "a", "a", "a")) + + checkAnswer(joined, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 880c30ba9f98d..8461f528277c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1039,4 +1039,60 @@ class 
KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-44641: duplicated records when SPJ is not triggered") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, items_schema, items_partitions) + sql(s""" + INSERT INTO testcat.ns.$items VALUES + (1, 'aa', 40.0, cast('2020-01-01' as timestamp)), + (1, 'aa', 41.0, cast('2020-01-15' as timestamp)), + (2, 'bb', 10.0, cast('2020-01-01' as timestamp)), + (2, 'bb', 10.5, cast('2020-01-01' as timestamp)), + (3, 'cc', 15.5, cast('2020-02-01' as timestamp))""") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"""INSERT INTO testcat.ns.$purchases VALUES + (1, 42.0, cast('2020-01-01' as timestamp)), + (1, 44.0, cast('2020-01-15' as timestamp)), + (1, 45.0, cast('2020-01-15' as timestamp)), + (2, 11.0, cast('2020-01-01' as timestamp)), + (3, 19.5, cast('2020-02-01' as timestamp))""") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClusteredEnabled => + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClusteredEnabled.toString) { + + // join keys are not the same as the partition keys, therefore SPJ is not triggered. + val df = sql( + s""" + SELECT id, name, i.price as purchase_price, p.item_id, p.price as sale_price + FROM testcat.ns.$items i JOIN testcat.ns.$purchases p + ON i.arrive_time = p.time ORDER BY id, purchase_price, p.item_id, sale_price + """) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, "shuffle should exist when SPJ is not used") + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index e70d04b7b5a6f..fb10e90b6ccea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean +import org.apache.spark.sql.catalyst.rules.RuleIdCollection import org.apache.spark.sql.catalyst.util.BadRecordException import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions} import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider @@ -499,6 +500,16 @@ class QueryExecutionErrorsSuite } } + test("SPARK-42330: rule id not found") { + checkError( + exception = intercept[SparkException] { + RuleIdCollection.getRuleId("incorrect") + }, + errorClass = "RULE_ID_NOT_FOUND", + parameters = Map("ruleName" -> "incorrect") + ) + } + test("CANNOT_RESTORE_PERMISSIONS_FOR_PATH: can't set permission") { withTable("t") { withSQLConf( @@ -714,6 +725,23 @@ class QueryExecutionErrorsSuite } 
} + test("CANNOT_PARSE_STRING_AS_DATATYPE: parse string as float use from_json") { + val jsonStr = """{"a": "str"}""" + checkError( + exception = intercept[SparkRuntimeException] { + sql(s"""SELECT from_json('$jsonStr', 'a FLOAT', map('mode','FAILFAST'))""").collect() + }, + errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_STRING_AS_DATATYPE", + parameters = Map( + "badRecord" -> jsonStr, + "failFastMode" -> "FAILFAST", + "fieldName" -> "`a`", + "fieldValue" -> "'str'", + "inputType" -> "StringType", + "targetType" -> "FloatType"), + sqlState = "22023") + } + test("BINARY_ARITHMETIC_OVERFLOW: byte plus byte result overflow") { withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { checkError( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index a35fb5f627145..2b9ec97bace1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -324,4 +324,18 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { DataSourceStrategy.translateFilter(catalystFilter, true) } } + + test("SPARK-41636: selectFilters returns predicates in deterministic order") { + + val predicates = Seq(EqualTo($"id", 1), EqualTo($"id", 2), + EqualTo($"id", 3), EqualTo($"id", 4), EqualTo($"id", 5), EqualTo($"id", 6)) + + val (unhandledPredicates, pushedFilters, handledFilters) = + DataSourceStrategy.selectFilters(FakeRelation(), predicates) + assert(unhandledPredicates.equals(predicates)) + assert(pushedFilters.zipWithIndex.forall { case (f, i) => + f.equals(sources.EqualTo("id", i + 1)) + }) + assert(handledFilters.isEmpty) + } }