Commit ae98ed7
Merge branch 'master' of https://github.com/apache/spark into remove_special_casing
itholic committed Sep 4, 2024
2 parents b536c6e + 4a79e73 commit ae98ed7
Showing 9 changed files with 163 additions and 78 deletions.
@@ -17,11 +17,13 @@
package org.apache.spark.sql

import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}

import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
@@ -591,6 +593,10 @@ class SparkSession private[sql] (
object SparkSession extends Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap

private val sessions = CacheBuilder
.newBuilder()
@@ -623,6 +629,51 @@ object SparkSession extends Logging {
}
}

/**
* Create a new Spark Connect server to connect locally.
*/
private[sql] def withLocalConnectServer[T](f: => T): T = {
synchronized {
val remoteString = sparkOptions
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

if (server.isEmpty &&
remoteString.exists(_.startsWith("local")) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// Remove SPARK_REMOTE so the launched server does not exclude the spark-sql jar from its classpath.
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}

// Give the server time to start. We will directly request to set the configurations,
// and this sleep makes the retries less noisy.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.getParent.resolve("stop-connect-server.sh").toString)
.start()
}
})
// scalastyle:on runtimeaddshutdownhook
}
}
f
}

/**
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/
@@ -765,6 +816,16 @@ object SparkSession extends Logging {
}

private def applyOptions(session: SparkSession): Unit = {
// Only attempt to set Spark SQL configurations here.
// Setting a static configuration might throw an exception, so
// simply ignore failures for now.
sparkOptions
.filter { case (k, _) =>
k.startsWith("spark.sql.")
}
.foreach { case (key, value) =>
Try(session.conf.set(key, value))
}
options.foreach { case (key, value) =>
session.conf.set(key, value)
}
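
For context, a minimal standalone sketch (not part of this diff) of what the Try wrapper above tolerates. The configuration keys are examples only; setting a static SQL configuration on a running session is expected to throw, and the Try simply swallows that failure:

import scala.util.Try

// Runtime SQL configuration: applied normally.
Try(session.conf.set("spark.sql.shuffle.partitions", "8"))
// Static SQL configuration: conf.set throws, so the value is silently skipped.
Try(session.conf.set("spark.sql.warehouse.dir", "/tmp/warehouse"))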
@@ -787,7 +848,7 @@
*
* @since 3.5.0
*/
def create(): SparkSession = {
def create(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
.getOrElse(SparkSession.this.create(builder.configuration))
setDefaultAndActiveSession(session)
@@ -807,7 +868,7 @@
*
* @since 3.5.0
*/
def getOrCreate(): SparkSession = {
def getOrCreate(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
.getOrElse({
var existingSession = sessions.get(builder.configuration)
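To show how the new withLocalConnectServer path above is meant to be exercised, here is a hedged usage sketch that is not part of this diff. It assumes SPARK_HOME points at a distribution containing sbin/start-connect-server.sh and that the Spark Connect client is on the classpath; the property values are illustrative:

import org.apache.spark.sql.SparkSession

object LocalConnectSketch {
  def main(args: Array[String]): Unit = {
    // A "local" remote triggers withLocalConnectServer: the script is launched and
    // spark.remote is rewritten to sc://localhost before the client connects.
    System.setProperty("spark.remote", "local")
    // Picked up by sparkOptions, forwarded to the server as --conf, and re-applied
    // to the session by applyOptions.
    System.setProperty("spark.sql.shuffle.partitions", "8")
    val spark = SparkSession.builder().getOrCreate()
    spark.range(5).collect().foreach(println)
    spark.close()
  }
}
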
@@ -17,10 +17,8 @@
package org.apache.spark.sql.application

import java.io.{InputStream, OutputStream}
import java.nio.file.Paths
import java.util.concurrent.Semaphore

import scala.util.Try
import scala.util.control.NonFatal

import ammonite.compiler.CodeClassWrapper
@@ -34,6 +32,7 @@ import ammonite.util.Util.newLine
import org.apache.spark.SparkBuildInfo.spark_version
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSession.withLocalConnectServer
import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser}

/**
@@ -64,37 +63,7 @@ Spark session available as 'spark'.
semaphore: Option[Semaphore] = None,
inputStream: InputStream = System.in,
outputStream: OutputStream = System.out,
errorStream: OutputStream = System.err): Unit = {
val configs: Map[String, String] =
sys.props
.filter(p =>
p._1.startsWith("spark.") &&
p._2.nonEmpty &&
// Don't include spark.remote that we manually set later.
!p._1.startsWith("spark.remote"))
.toMap

val remoteString: Option[String] =
Option(System.getProperty("spark.remote")) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))

if (remoteString.exists(_.startsWith("local"))) {
server = Some {
val args = Seq(
Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString,
"--master",
remoteString.get) ++ configs.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}
// Let the server start. We will directly request to set the configurations
// and this sleep makes less noisy with retries.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")
}

errorStream: OutputStream = System.err): Unit = withLocalConnectServer {
// Build the client.
val client =
try {
@@ -118,13 +87,6 @@

// Build the session.
val spark = SparkSession.builder().client(client).getOrCreate()

// The configurations might not be all runtime configurations.
// Try to set them with ignoring failures for now.
configs
.filter(_._1.startsWith("spark.sql"))
.foreach { case (k, v) => Try(spark.conf.set(k, v)) }

val sparkBind = new Bind("spark", spark)

// Add the proper imports and register a [[ClassFinder]].
@@ -197,18 +159,12 @@ Spark session available as 'spark'.
}
}
}
try {
if (semaphore.nonEmpty) {
// Used for testing.
main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
} else {
main.run(sparkBind)
}
} finally {
if (server.isDefined) {
new ProcessBuilder(Paths.get(sparkHome, "sbin", "stop-connect-server.sh").toString)
.start()
}

if (semaphore.nonEmpty) {
// Used for testing.
main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
} else {
main.run(sparkBind)
}
}
}
42 changes: 29 additions & 13 deletions core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -25,6 +25,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.Try

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.{SparkConf, SparkUserAppException}
import org.apache.spark.api.python.{Py4JServer, PythonUtils}
import org.apache.spark.internal.config._
@@ -50,18 +53,21 @@ object PythonRunner {
val formattedPythonFile = formatPath(pythonFile)
val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles))

val gatewayServer = new Py4JServer(sparkConf)
var gatewayServer: Option[Py4JServer] = None
if (sparkConf.getOption("spark.remote").isEmpty) {
gatewayServer = Some(new Py4JServer(sparkConf))

val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() })
thread.setName("py4j-gateway-init")
thread.setDaemon(true)
thread.start()
val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.get.start() })
thread.setName("py4j-gateway-init")
thread.setDaemon(true)
thread.start()

// Wait until the gateway server has started, so that we know which port it is bound to.
// `gatewayServer.start()` will start a new thread and run the server code there, after
// initializing the socket, so the thread started above will end as soon as the server is
// ready to serve connections.
thread.join()
// Wait until the gateway server has started, so that we know which port it is bound to.
// `gatewayServer.start()` will start a new thread and run the server code there, after
// initializing the socket, so the thread started above will end as soon as the server is
// ready to serve connections.
thread.join()
}

// Build up a PYTHONPATH that includes the Spark assembly (where this class is), the
// python directories in SPARK_HOME (if set), and any files in the pyFiles argument
@@ -74,12 +80,22 @@
// Launch Python process
val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
val env = builder.environment()
if (sparkConf.getOption("spark.remote").nonEmpty) {
// For a non-local remote, pass configurations via environment variables so the
// Spark Connect client can set them. For local remotes, they will be set
// via Py4J.
val grouped = sparkConf.getAll.toMap.grouped(10).toSeq
env.put("PYSPARK_REMOTE_INIT_CONF_LEN", grouped.length.toString)
grouped.zipWithIndex.foreach { case (group, idx) =>
env.put(s"PYSPARK_REMOTE_INIT_CONF_$idx", compact(render(group)))
}
}
sparkConf.getOption("spark.remote").foreach(url => env.put("SPARK_REMOTE", url))
env.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
env.put("PYSPARK_GATEWAY_SECRET", gatewayServer.secret)
gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_PORT", s.getListeningPort.toString))
gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_SECRET", s.secret))
// pass conf spark.pyspark.python to python process, the only way to pass info to
// python process is through environment variable.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
@@ -103,7 +119,7 @@ object PythonRunner {
throw new SparkUserAppException(exitCode)
}
} finally {
gatewayServer.shutdown()
gatewayServer.foreach(_.shutdown())
}
}

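As a rough illustration of the environment-variable handoff added above (a standalone sketch, not from this diff; the configuration values are invented), the configurations are split into groups of at most ten entries and each group is serialized to a JSON object that the PySpark Connect client is expected to read back:

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

object InitConfEnvSketch {
  def main(args: Array[String]): Unit = {
    val conf = Map(
      "spark.app.name" -> "demo",
      "spark.remote" -> "sc://remote-host:15002",
      "spark.sql.ansi.enabled" -> "true")
    // Mirrors PythonRunner: groups of up to 10 entries, one env var per group.
    val grouped = conf.grouped(10).toSeq
    println(s"PYSPARK_REMOTE_INIT_CONF_LEN=${grouped.length}")
    grouped.zipWithIndex.foreach { case (group, idx) =>
      println(s"PYSPARK_REMOTE_INIT_CONF_$idx=${compact(render(group))}")
    }
  }
}
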
18 changes: 13 additions & 5 deletions core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -449,16 +449,24 @@ private[spark] object UIUtils extends Logging {
val startRatio = if (total == 0) 0.0 else (boundedStarted.toDouble / total) * 100
val startWidth = "width: %s%%".format(startRatio)

val killTaskReasonText = reasonToNumKilled.toSeq.sortBy(-_._2).map {
case (reason, count) => s" ($count killed: $reason)"
}.mkString
val progressTitle = s"$completed/$total" + {
if (started > 0) s" ($started running)" else ""
} + {
if (failed > 0) s" ($failed failed)" else ""
} + {
if (skipped > 0) s" ($skipped skipped)" else ""
} + killTaskReasonText

<div class="progress">
<span style="text-align:center; position:absolute; width:100%;">
<span style="text-align:center; position:absolute; width:100%;" title={progressTitle}>
{completed}/{total}
{ if (failed == 0 && skipped == 0 && started > 0) s"($started running)" }
{ if (failed > 0) s"($failed failed)" }
{ if (skipped > 0) s"($skipped skipped)" }
{ reasonToNumKilled.toSeq.sortBy(-_._2).map {
case (reason, count) => s"($count killed: $reason)"
}
}
{ killTaskReasonText }
</span>
<div class="progress-bar progress-completed" style={completeWidth}></div>
<div class="progress-bar progress-started" style={startWidth}></div>
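For reference, a small standalone sketch (not part of this diff) of the tooltip text the new progressTitle logic produces; the counts and kill reason are invented for illustration:

object ProgressTitleSketch {
  def main(args: Array[String]): Unit = {
    val (completed, total, started, failed, skipped) = (7, 10, 2, 1, 0)
    val reasonToNumKilled = Map("stage cancelled" -> 1)
    val killTaskReasonText = reasonToNumKilled.toSeq.sortBy(-_._2).map {
      case (reason, count) => s" ($count killed: $reason)"
    }.mkString
    val progressTitle = s"$completed/$total" + {
      if (started > 0) s" ($started running)" else ""
    } + {
      if (failed > 0) s" ($failed failed)" else ""
    } + {
      if (skipped > 0) s" ($skipped skipped)" else ""
    } + killTaskReasonText
    // Prints: 7/10 (2 running) (1 failed) (1 killed: stage cancelled)
    println(progressTitle)
  }
}
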
17 changes: 17 additions & 0 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -1802,6 +1802,23 @@ class SparkSubmitSuite
val (_, classpath, _, _) = submit.prepareSubmitEnvironment(appArgs)
assert(classpath.contains("."))
}

// Requires Python dependencies for Spark Connect. Should be enabled by default.
ignore("Spark Connect application submission (Python)") {
val pyFile = File.createTempFile("remote_test", ".py")
pyFile.deleteOnExit()
val content =
"from pyspark.sql import SparkSession;" +
"spark = SparkSession.builder.getOrCreate();" +
"assert 'connect' in str(type(spark));" +
"assert spark.range(1).first()[0] == 0"
FileUtils.write(pyFile, content, StandardCharsets.UTF_8)
val args = Seq(
"--name", "testPyApp",
"--remote", "local",
pyFile.getAbsolutePath)
runSparkSubmit(args)
}
}

object JarCreationTest extends Logging {
2 changes: 1 addition & 1 deletion docs/rdd-programming-guide.md
@@ -46,7 +46,7 @@ Spark applications in Python can either be run with the `bin/spark-submit` script

{% highlight python %}
install_requires=[
'pyspark=={site.SPARK_VERSION}'
'pyspark=={{site.SPARK_VERSION}}'
]
{% endhighlight %}

@@ -214,7 +214,9 @@ List<String> buildClassPath(String appClassPath) throws IOException {
addToClassPath(cp, f.toString());
}
}
if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL"))) {
// If we're in 'spark.local.connect' mode, it should create a Spark Classic SparkContext
// that launches the Spark Connect server.
if (isRemote && System.getenv("SPARK_LOCAL_CONNECT") == null) {
for (File f: new File(jarsDir).listFiles()) {
// Exclude Spark Classic SQL and Spark Connect server jars
// if we're in Spark Connect Shell. Also exclude Spark SQL API and
@@ -82,10 +82,6 @@ public List<String> buildCommand(Map<String, String> env)
javaOptsKeys.add("SPARK_BEELINE_OPTS");
yield "SPARK_BEELINE_MEMORY";
}
case "org.apache.spark.sql.application.ConnectRepl" -> {
isRemote = true;
yield "SPARK_DRIVER_MEMORY";
}
default -> "SPARK_DRIVER_MEMORY";
};
