diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index a6aa381eb14a5..9e84b880fc968 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -98,14 +98,14 @@ def main(infile: IO, outfile: IO) -> None:
     """
     Runs the Python UDTF's `analyze` static method.
 
-    This process will be invoked from `UserDefinedPythonTableFunction.analyzeInPython` in JVM
-    and receive the Python UDTF and its arguments for the `analyze` static method,
+    This process will be invoked from `UserDefinedPythonTableFunctionAnalyzeRunner.runInPython`
+    in the JVM and receive the Python UDTF and its arguments for the `analyze` static method,
     and call the `analyze` static method, and send back a AnalyzeResult as a result of the method.
     """
     try:
         check_python_version(infile)
 
-        memory_limit_mb = int(os.environ.get("PYSPARK_UDTF_ANALYZER_MEMORY_MB", "-1"))
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)
 
         setup_spark_files(infile)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 12ec9e911d31d..000694f6f1bba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3008,14 +3008,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
-  val PYTHON_TABLE_UDF_ANALYZER_MEMORY =
-    buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory")
-      .doc("The amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB " +
-        "unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be " +
-        "limited to this amount. If not set, Spark will not limit Python's " +
-        "memory use and it is up to the application to avoid exceeding the overhead memory space " +
-        "shared with other non-JVM processes.\nNote: Windows does not support resource limiting " +
-        "and actual resource is not limited on MacOS.")
+  val PYTHON_PLANNER_EXEC_MEMORY =
+    buildConf("spark.sql.planner.pythonExecution.memory")
+      .doc("Specifies the memory allocation for executing Python code in the Spark driver, " +
+        "in MiB. When set, it caps the memory for Python execution to the specified amount. " +
+        "If not set, Spark will not limit Python's memory usage and it is up to the application " +
+        "to avoid exceeding the overhead memory space shared with other non-JVM processes.\n" +
+        "Note: Windows does not support resource limiting and actual resource is not limited " +
+        "on MacOS.")
     .version("4.0.0")
     .bytesConf(ByteUnit.MiB)
     .createOptional
@@ -5157,7 +5157,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def pysparkWorkerPythonExecutable: Option[String] =
     getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE)
 
-  def pythonUDTFAnalyzerMemory: Option[Long] = getConf(PYTHON_TABLE_UDF_ANALYZER_MEMORY)
+  def pythonPlannerExecMemory: Option[Long] = getConf(PYTHON_PLANNER_EXEC_MEMORY)
 
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
new file mode 100644
index 0000000000000..183b96bb982c7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.SelectionKey
+import java.util.HashMap
+
+import scala.jdk.CollectionConverters._
+
+import net.razorvine.pickle.Pickler
+
+import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException}
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.DirectByteBufferOutputStream
+
+/**
+ * A helper class to run Python functions in the Spark driver.
+ */
+abstract class PythonPlannerRunner[T](func: PythonFunction) {
+
+  protected val workerModule: String
+
+  protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit
+
+  protected def receiveFromPython(dataIn: DataInputStream): T
+
+  def runInPython(): T = {
+    val env = SparkEnv.get
+    val bufferSize: Int = env.conf.get(BUFFER_SIZE)
+    val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+    val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
+
+    val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+    val envVars = new HashMap[String, String](func.envVars)
+    val pythonExec = func.pythonExec
+    val pythonVer = func.pythonVer
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    val broadcastVars = func.broadcastVars.asScala.toSeq
+    val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset())
+
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    if (simplifiedTraceback) {
+      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+    }
+    workerMemoryMb.foreach { memoryMb =>
+      envVars.put("PYSPARK_PLANNER_MEMORY_MB", memoryMb.toString)
+    }
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
+
+    EvaluatePython.registerPicklers()
+    val pickler = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
+
+    val (worker: PythonWorker, _) =
+      env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
+    var releasedOrClosed = false
+    val bufferStream = new DirectByteBufferOutputStream()
+    try {
+      val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize))
+
+      PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+      PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
+      PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
+
+      writeToPython(dataOut, pickler)
+
+      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+      dataOut.flush()
+
+      val dataIn = new DataInputStream(new BufferedInputStream(
+        new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
+
+      val res = receiveFromPython(dataIn)
+
+      PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn)
+      Option(func.accumulator).foreach(_.merge(maybeAccumulator.get))
+
+      dataIn.readInt() match {
+        case SpecialLengths.END_OF_STREAM if reuseWorker =>
+          env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+        case _ =>
+          env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+      }
+      releasedOrClosed = true
+
+      res
+    } catch {
+      case eof: EOFException =>
+        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+    } finally {
+      try {
+        bufferStream.close()
+      } finally {
+        if (!releasedOrClosed) {
+          // An error happened. Force to close the worker.
+          env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+        }
+      }
+    }
+  }
+
+  /**
+   * A wrapper of the non-blocking IO to write to/read from the worker.
+   *
+   * Since we use non-blocking IO to communicate with workers (see SPARK-44705),
+   * a wrapper is needed to do IO with the worker.
+ * This is a port and simplified version of `PythonRunner.ReaderInputStream`, + * and only supports to write all at once and then read all. + */ + private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream { + + private[this] val temp = new Array[Byte](1) + + override def read(): Int = { + val n = read(temp) + if (n <= 0) { + -1 + } else { + // Signed byte to unsigned integer + temp(0) & 0xff + } + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = { + val buf = ByteBuffer.wrap(b, off, len) + var n = 0 + while (n == 0) { + worker.selector.select() + if (worker.selectionKey.isReadable) { + n = worker.channel.read(buf) + } + if (worker.selectionKey.isWritable) { + var acceptsInput = true + while (acceptsInput && buffer.hasRemaining) { + val n = worker.channel.write(buffer) + acceptsInput = n > 0 + } + if (!buffer.hasRemaining) { + // We no longer have any data to write to the socket. + worker.selectionKey.interestOps(SelectionKey.OP_READ) + } + } + } + n + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index d8d3cc9b7fc43..f2f952f079e2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -17,28 +17,19 @@ package org.apache.spark.sql.execution.python -import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream} -import java.nio.ByteBuffer -import java.nio.channels.SelectionKey -import java.util.HashMap +import java.io.{DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.jdk.CollectionConverters._ import net.razorvine.pickle.Pickler -import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException} -import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths} -import org.apache.spark.internal.config.BUFFER_SIZE -import org.apache.spark.internal.config.Python._ +import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths} import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, SortOrder, UnresolvedPolymorphicPythonUDTF} import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.DirectByteBufferOutputStream /** * A user-defined Python function. This is used by the Python API. 
@@ -141,13 +132,17 @@ case class UserDefinedPythonTableFunction( case NamedArgumentExpression(_, _: FunctionTableSubqueryArgumentExpression) => true case _ => false } + val runAnalyzeInPython = (func: PythonFunction, exprs: Seq[Expression]) => { + val runner = new UserDefinedPythonTableFunctionAnalyzeRunner(func, exprs, tableArgs) + runner.runInPython() + } UnresolvedPolymorphicPythonUDTF( name = name, func = func, children = exprs, evalType = pythonEvalType, udfDeterministic = udfDeterministic, - resolveElementMetadata = UserDefinedPythonTableFunction.analyzeInPython(_, _, tableArgs)) + resolveElementMetadata = runAnalyzeInPython) } Generate( udtf, @@ -166,228 +161,106 @@ case class UserDefinedPythonTableFunction( } } -object UserDefinedPythonTableFunction { - - private[this] val workerModule = "pyspark.sql.worker.analyze_udtf" - - /** - * Runs the Python UDTF's `analyze` static method. - * - * When the Python UDTF is defined without a static return type, - * the analyzer will call this while resolving table-valued functions. - * - * This expects the Python UDTF to have `analyze` static method that take arguments: - * - * - The number and order of arguments are the same as the UDTF inputs - * - Each argument is an `AnalyzeArgument`, containing: - * - data_type: DataType - * - value: Any: if the argument is foldable; otherwise None - * - is_table: bool: True if the argument is TABLE - * - * and that return an `AnalyzeResult`. - * - * It serializes/deserializes the data types via JSON, - * and the values for the case the argument is foldable are pickled. - * - * `AnalysisException` with the error class "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON" - * will be thrown when an exception is raised in Python. - */ - def analyzeInPython( - func: PythonFunction, - exprs: Seq[Expression], - tableArgs: Seq[Boolean]): PythonUDTFAnalyzeResult = { - val env = SparkEnv.get - val bufferSize: Int = env.conf.get(BUFFER_SIZE) - val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) - val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE) - val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - val workerMemoryMb = SQLConf.get.pythonUDTFAnalyzerMemory - - val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) - - val envVars = new HashMap[String, String](func.envVars) - val pythonExec = func.pythonExec - val pythonVer = func.pythonVer - val pythonIncludes = func.pythonIncludes.asScala.toSet - val broadcastVars = func.broadcastVars.asScala.toSeq - val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset()) - - envVars.put("SPARK_LOCAL_DIRS", localdir) - if (reuseWorker) { - envVars.put("SPARK_REUSE_WORKER", "1") - } - if (simplifiedTraceback) { - envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") - } - workerMemoryMb.foreach { memoryMb => - envVars.put("PYSPARK_UDTF_ANALYZER_MEMORY_MB", memoryMb.toString) - } - envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) - envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) - - envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - - EvaluatePython.registerPicklers() - val pickler = new Pickler(/* useMemo = */ true, - /* valueCompare = */ false) - - val (worker: PythonWorker, _) = - env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap) - var releasedOrClosed = false - val bufferStream = new DirectByteBufferOutputStream() - try { - val dataOut = new 
DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize)) - - PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) - PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut) - PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut) - - // Send Python UDTF - PythonWorkerUtils.writePythonFunction(func, dataOut) - - // Send arguments - dataOut.writeInt(exprs.length) - exprs.zip(tableArgs).foreach { case (expr, is_table) => - PythonWorkerUtils.writeUTF(expr.dataType.json, dataOut) - if (expr.foldable) { - dataOut.writeBoolean(true) - val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType)) - PythonWorkerUtils.writeBytes(obj, dataOut) - } else { - dataOut.writeBoolean(false) - } - dataOut.writeBoolean(is_table) - // If the expr is NamedArgumentExpression, send its name. - expr match { - case NamedArgumentExpression(key, _) => - dataOut.writeBoolean(true) - PythonWorkerUtils.writeUTF(key, dataOut) - case _ => - dataOut.writeBoolean(false) - } - } - - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - - val dataIn = new DataInputStream(new BufferedInputStream( - new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize)) - - // Receive the schema or an exception raised in Python worker. - val length = dataIn.readInt() - if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) { - val msg = PythonWorkerUtils.readUTF(dataIn) - throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg) - } - - val schema = DataType.fromJson( - PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType] - - // Receive the pickled AnalyzeResult buffer, if any. - val pickledAnalyzeResult: Array[Byte] = PythonWorkerUtils.readBytes(dataIn) - - // Receive whether the "with single partition" property is requested. - val withSinglePartition = dataIn.readInt() == 1 - // Receive the list of requested partitioning columns, if any. - val partitionByColumns = ArrayBuffer.empty[Expression] - val numPartitionByColumns = dataIn.readInt() - for (_ <- 0 until numPartitionByColumns) { - val columnName = PythonWorkerUtils.readUTF(dataIn) - partitionByColumns.append(UnresolvedAttribute(columnName)) - } - // Receive the list of requested ordering columns, if any. - val orderBy = ArrayBuffer.empty[SortOrder] - val numOrderByItems = dataIn.readInt() - for (_ <- 0 until numOrderByItems) { - val columnName = PythonWorkerUtils.readUTF(dataIn) - val direction = if (dataIn.readInt() == 1) Ascending else Descending - val overrideNullsFirst = dataIn.readInt() - overrideNullsFirst match { - case 0 => - orderBy.append(SortOrder(UnresolvedAttribute(columnName), direction)) - case 1 => orderBy.append( - SortOrder(UnresolvedAttribute(columnName), direction, NullsFirst, Seq.empty)) - case 2 => orderBy.append( - SortOrder(UnresolvedAttribute(columnName), direction, NullsLast, Seq.empty)) - } +/** + * Runs the Python UDTF's `analyze` static method. + * + * When the Python UDTF is defined without a static return type, + * the analyzer will call this while resolving table-valued functions. + * + * This expects the Python UDTF to have `analyze` static method that take arguments: + * + * - The number and order of arguments are the same as the UDTF inputs + * - Each argument is an `AnalyzeArgument`, containing: + * - data_type: DataType + * - value: Any: if the argument is foldable; otherwise None + * - is_table: bool: True if the argument is TABLE + * + * and that return an `AnalyzeResult`. 
+ * + * It serializes/deserializes the data types via JSON, + * and the values for the case the argument is foldable are pickled. + * + * `AnalysisException` with the error class "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON" + * will be thrown when an exception is raised in Python. + */ +class UserDefinedPythonTableFunctionAnalyzeRunner( + func: PythonFunction, + exprs: Seq[Expression], + tableArgs: Seq[Boolean]) extends PythonPlannerRunner[PythonUDTFAnalyzeResult](func) { + + override val workerModule = "pyspark.sql.worker.analyze_udtf" + + override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = { + // Send Python UDTF + PythonWorkerUtils.writePythonFunction(func, dataOut) + + // Send arguments + dataOut.writeInt(exprs.length) + exprs.zip(tableArgs).foreach { case (expr, is_table) => + PythonWorkerUtils.writeUTF(expr.dataType.json, dataOut) + if (expr.foldable) { + dataOut.writeBoolean(true) + val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType)) + PythonWorkerUtils.writeBytes(obj, dataOut) + } else { + dataOut.writeBoolean(false) } - - PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn) - Option(func.accumulator).foreach(_.merge(maybeAccumulator.get)) - - dataIn.readInt() match { - case SpecialLengths.END_OF_STREAM if reuseWorker => - env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) + dataOut.writeBoolean(is_table) + // If the expr is NamedArgumentExpression, send its name. + expr match { + case NamedArgumentExpression(key, _) => + dataOut.writeBoolean(true) + PythonWorkerUtils.writeUTF(key, dataOut) case _ => - env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) - } - releasedOrClosed = true - - PythonUDTFAnalyzeResult( - schema = schema, - withSinglePartition = withSinglePartition, - partitionByExpressions = partitionByColumns.toSeq, - orderByExpressions = orderBy.toSeq, - pickledAnalyzeResult = pickledAnalyzeResult) - } catch { - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } finally { - try { - bufferStream.close() - } finally { - if (!releasedOrClosed) { - // An error happened. Force to close the worker. - env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) - } + dataOut.writeBoolean(false) } } } - /** - * A wrapper of the non-blocking IO to write to/read from the worker. - * - * Since we use non-blocking IO to communicate with workers; see SPARK-44705, - * a wrapper is needed to do IO with the worker. - * This is a port and simplified version of `PythonRunner.ReaderInputStream`, - * and only supports to write all at once and then read all. - */ - private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream { + override protected def receiveFromPython(dataIn: DataInputStream): PythonUDTFAnalyzeResult = { + // Receive the schema or an exception raised in Python worker. 
+ val length = dataIn.readInt() + if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg) + } - private[this] val temp = new Array[Byte](1) + val schema = DataType.fromJson( + PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType] - override def read(): Int = { - val n = read(temp) - if (n <= 0) { - -1 - } else { - // Signed byte to unsigned integer - temp(0) & 0xff - } - } + // Receive the pickled AnalyzeResult buffer, if any. + val pickledAnalyzeResult: Array[Byte] = PythonWorkerUtils.readBytes(dataIn) - override def read(b: Array[Byte], off: Int, len: Int): Int = { - val buf = ByteBuffer.wrap(b, off, len) - var n = 0 - while (n == 0) { - worker.selector.select() - if (worker.selectionKey.isReadable) { - n = worker.channel.read(buf) - } - if (worker.selectionKey.isWritable) { - var acceptsInput = true - while (acceptsInput && buffer.hasRemaining) { - val n = worker.channel.write(buffer) - acceptsInput = n > 0 - } - if (!buffer.hasRemaining) { - // We no longer have any data to write to the socket. - worker.selectionKey.interestOps(SelectionKey.OP_READ) - } - } + // Receive whether the "with single partition" property is requested. + val withSinglePartition = dataIn.readInt() == 1 + // Receive the list of requested partitioning columns, if any. + val partitionByColumns = ArrayBuffer.empty[Expression] + val numPartitionByColumns = dataIn.readInt() + for (_ <- 0 until numPartitionByColumns) { + val columnName = PythonWorkerUtils.readUTF(dataIn) + partitionByColumns.append(UnresolvedAttribute(columnName)) + } + // Receive the list of requested ordering columns, if any. + val orderBy = ArrayBuffer.empty[SortOrder] + val numOrderByItems = dataIn.readInt() + for (_ <- 0 until numOrderByItems) { + val columnName = PythonWorkerUtils.readUTF(dataIn) + val direction = if (dataIn.readInt() == 1) Ascending else Descending + val overrideNullsFirst = dataIn.readInt() + overrideNullsFirst match { + case 0 => + orderBy.append(SortOrder(UnresolvedAttribute(columnName), direction)) + case 1 => orderBy.append( + SortOrder(UnresolvedAttribute(columnName), direction, NullsFirst, Seq.empty)) + case 2 => orderBy.append( + SortOrder(UnresolvedAttribute(columnName), direction, NullsLast, Seq.empty)) } - n } + PythonUDTFAnalyzeResult( + schema = schema, + withSinglePartition = withSinglePartition, + partitionByExpressions = partitionByColumns.toSeq, + orderByExpressions = orderBy.toSeq, + pickledAnalyzeResult = pickledAnalyzeResult) } }
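As an illustration of how the new `PythonPlannerRunner` contract introduced above is meant to be extended by other driver-side Python calls, here is a minimal sketch of a hypothetical subclass. The worker module name and the string payload are invented for the example; only `workerModule`, `writeToPython`, `receiveFromPython`, `runInPython`, and the `PythonWorkerUtils` read/write helpers come from this patch.

package org.apache.spark.sql.execution.python

import java.io.{DataInputStream, DataOutputStream}

import net.razorvine.pickle.Pickler

import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils}

// Hypothetical subclass: sends one UTF-8 string to a driver-side Python worker
// and reads one UTF-8 string back.
class EchoPlannerRunner(func: PythonFunction, message: String)
  extends PythonPlannerRunner[String](func) {

  // Hypothetical worker module; a real one would live under pyspark/sql/worker/.
  override protected val workerModule: String = "pyspark.sql.worker.echo"

  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
    // The pickler is available for sending Python objects; a plain string only needs writeUTF.
    PythonWorkerUtils.writeUTF(message, dataOut)
  }

  override protected def receiveFromPython(dataIn: DataInputStream): String = {
    PythonWorkerUtils.readUTF(dataIn)
  }
}

// Usage on the driver: runInPython() starts (or reuses) the worker, applies the
// spark.sql.planner.pythonExecution.memory limit, exchanges the payload, and then
// releases or destroys the worker.
//   val echoed = new EchoPlannerRunner(pythonFunc, "ping").runInPython()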