[SPARK-45505][PYTHON] Refactor analyzeInPython to make it reusable
### What changes were proposed in this pull request?

Currently, the `analyzeInPython` method in the `UserDefinedPythonTableFunction` object can start a Python process on the driver and run a Python function in that process. This PR refactors this logic into a reusable runner class.
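
For context, the new `PythonPlannerRunner[T]` class added in this commit exposes three hooks: `workerModule`, `writeToPython`, and `receiveFromPython`, while `runInPython()` handles worker setup, the memory limit, and cleanup. Below is a minimal, hypothetical sketch of a concrete runner built on it; the `EchoRunner` name, its worker module, and its wire format are illustrative only and not part of this PR.

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

import net.razorvine.pickle.Pickler

import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.execution.python.PythonPlannerRunner

// Hypothetical sketch: send one UTF-8 string to a driver-side Python worker
// module and read one string back. Not part of this commit.
class EchoRunner(func: PythonFunction, message: String)
  extends PythonPlannerRunner[String](func) {

  // Driver-side Python worker module assumed to echo back the string it receives.
  override protected val workerModule: String = "pyspark.sql.worker.echo"

  // Serialize the request as a length-prefixed UTF-8 payload (the pickler is unused here).
  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
    val bytes = message.getBytes(StandardCharsets.UTF_8)
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
  }

  // Deserialize the response written back by the worker module.
  override protected def receiveFromPython(dataIn: DataInputStream): String = {
    val len = dataIn.readInt()
    val bytes = new Array[Byte](len)
    dataIn.readFully(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }
}

// Usage from driver-side planning code: new EchoRunner(pythonFunc, "hello").runInPython()
```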

### Why are the changes needed?

To make the code more reusable.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#43340 from allisonwang-db/spark-45505-refactor-analyze-in-py.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
allisonwang-db authored and ueshin committed Oct 13, 2023
1 parent e720cce commit 280f6b3
Showing 4 changed files with 286 additions and 236 deletions.
6 changes: 3 additions & 3 deletions python/pyspark/sql/worker/analyze_udtf.py
@@ -98,14 +98,14 @@ def main(infile: IO, outfile: IO) -> None:
"""
Runs the Python UDTF's `analyze` static method.
-This process will be invoked from `UserDefinedPythonTableFunction.analyzeInPython` in JVM
-and receive the Python UDTF and its arguments for the `analyze` static method,
+This process will be invoked from `UserDefinedPythonTableFunctionAnalyzeRunner.runInPython`
+in JVM and receive the Python UDTF and its arguments for the `analyze` static method,
and call the `analyze` static method, and send back an AnalyzeResult as a result of the method.
"""
try:
check_python_version(infile)

-memory_limit_mb = int(os.environ.get("PYSPARK_UDTF_ANALYZER_MEMORY_MB", "-1"))
+memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)

setup_spark_files(infile)
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3008,14 +3008,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

-val PYTHON_TABLE_UDF_ANALYZER_MEMORY =
-buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory")
-.doc("The amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB " +
-"unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be " +
-"limited to this amount. If not set, Spark will not limit Python's " +
-"memory use and it is up to the application to avoid exceeding the overhead memory space " +
-"shared with other non-JVM processes.\nNote: Windows does not support resource limiting " +
-"and actual resource is not limited on MacOS.")
+val PYTHON_PLANNER_EXEC_MEMORY =
+buildConf("spark.sql.planner.pythonExecution.memory")
+.doc("Specifies the memory allocation for executing Python code in Spark driver, in MiB. " +
+"When set, it caps the memory for Python execution to the specified amount. " +
+"If not set, Spark will not limit Python's memory usage and it is up to the application " +
+"to avoid exceeding the overhead memory space shared with other non-JVM processes.\n" +
+"Note: Windows does not support resource limiting and actual resource is not limited " +
+"on MacOS.")
.version("4.0.0")
.bytesConf(ByteUnit.MiB)
.createOptional
@@ -5157,7 +5157,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def pysparkWorkerPythonExecutable: Option[String] =
getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE)

-def pythonUDTFAnalyzerMemory: Option[Long] = getConf(PYTHON_TABLE_UDF_ANALYZER_MEMORY)
+def pythonPlannerExecMemory: Option[Long] = getConf(PYTHON_PLANNER_EXEC_MEMORY)

def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)

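
For reference, the renamed conf can be passed like any other SQL conf when building a session; a minimal sketch follows (the app name and value are arbitrary, and a bare number is interpreted in MiB per the conf's doc).

```scala
import org.apache.spark.sql.SparkSession

// Cap memory for Python code executed in the Spark driver (e.g. the UDTF analyzer)
// to 512 MiB using the renamed conf.
val spark = SparkSession.builder()
  .appName("python-planner-memory-demo")
  .config("spark.sql.planner.pythonExecution.memory", "512")
  .getOrCreate()
```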
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala (new file)
@@ -0,0 +1,177 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.python

import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream}
import java.nio.ByteBuffer
import java.nio.channels.SelectionKey
import java.util.HashMap

import scala.jdk.CollectionConverters._

import net.razorvine.pickle.Pickler

import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException}
import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths}
import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.internal.config.Python._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.DirectByteBufferOutputStream

/**
* A helper class to run Python functions in the Spark driver.
*/
abstract class PythonPlannerRunner[T](func: PythonFunction) {

protected val workerModule: String

protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit

protected def receiveFromPython(dataIn: DataInputStream): T

def runInPython(): T = {
val env = SparkEnv.get
val bufferSize: Int = env.conf.get(BUFFER_SIZE)
val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory

val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)

val envVars = new HashMap[String, String](func.envVars)
val pythonExec = func.pythonExec
val pythonVer = func.pythonVer
val pythonIncludes = func.pythonIncludes.asScala.toSet
val broadcastVars = func.broadcastVars.asScala.toSeq
val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset())

envVars.put("SPARK_LOCAL_DIRS", localdir)
if (reuseWorker) {
envVars.put("SPARK_REUSE_WORKER", "1")
}
if (simplifiedTraceback) {
envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
}
workerMemoryMb.foreach { memoryMb =>
envVars.put("PYSPARK_PLANNER_MEMORY_MB", memoryMb.toString)
}
envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)

envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))

EvaluatePython.registerPicklers()
val pickler = new Pickler(/* useMemo = */ true,
/* valueCompare = */ false)

val (worker: PythonWorker, _) =
env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
var releasedOrClosed = false
val bufferStream = new DirectByteBufferOutputStream()
try {
val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize))

PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)

writeToPython(dataOut, pickler)

dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()

val dataIn = new DataInputStream(new BufferedInputStream(
new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))

val res = receiveFromPython(dataIn)

PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn)
Option(func.accumulator).foreach(_.merge(maybeAccumulator.get))

dataIn.readInt() match {
case SpecialLengths.END_OF_STREAM if reuseWorker =>
env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
case _ =>
env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
}
releasedOrClosed = true

res
} catch {
case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
} finally {
try {
bufferStream.close()
} finally {
if (!releasedOrClosed) {
// An error happened. Force to close the worker.
env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
}
}
}
}

/**
* A wrapper of the non-blocking IO to write to/read from the worker.
*
* Since we use non-blocking IO to communicate with workers (see SPARK-44705),
* a wrapper is needed to do IO with the worker.
* This is a simplified port of `PythonRunner.ReaderInputStream`, and it only
* supports writing everything at once and then reading everything back.
*/
private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream {

private[this] val temp = new Array[Byte](1)

override def read(): Int = {
val n = read(temp)
if (n <= 0) {
-1
} else {
// Signed byte to unsigned integer
temp(0) & 0xff
}
}

override def read(b: Array[Byte], off: Int, len: Int): Int = {
val buf = ByteBuffer.wrap(b, off, len)
var n = 0
while (n == 0) {
worker.selector.select()
if (worker.selectionKey.isReadable) {
n = worker.channel.read(buf)
}
if (worker.selectionKey.isWritable) {
var acceptsInput = true
while (acceptsInput && buffer.hasRemaining) {
val n = worker.channel.write(buffer)
acceptsInput = n > 0
}
if (!buffer.hasRemaining) {
// We no longer have any data to write to the socket.
worker.selectionKey.interestOps(SelectionKey.OP_READ)
}
}
}
n
}
}
}