
Commit

Merge pull request #1519 from apache/master
GulajavaMinistudio authored Jul 8, 2023
2 parents 7abf7eb + 57bbb4c commit 4e4d26f
Showing 67 changed files with 2,135 additions and 624 deletions.
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object SparkSerDerseUtils {
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bos)
oos.writeObject(o)
oos.close()
bos.toByteArray
}
}
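The helper above only covers the write side; no matching deserializer is added in this diff. A minimal round-trip sketch, assuming a hypothetical payload case class and plain Java deserialization on the read side:

import java.io.{ByteArrayInputStream, ObjectInputStream}

import org.apache.spark.util.SparkSerDerseUtils

object SerDeRoundTripSketch extends App {
  // Hypothetical payload type, standing in for UdfPacket / ForeachWriterPacket.
  case class Payload(name: String, arity: Int)

  // Write side: the helper introduced by this commit.
  val bytes: Array[Byte] = SparkSerDerseUtils.serialize(Payload("upper", 1))

  // Read side: plain ObjectInputStream, since no deserialize helper exists here.
  val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
  val restored =
    try ois.readObject().asInstanceOf[Payload]
    finally ois.close()
  assert(restored == Payload("upper", 1))
}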
@@ -22,8 +22,8 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration

import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToMillis
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.microsToMillis
import org.apache.spark.sql.catalyst.util.SparkIntervalUtils
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.unsafe.types.UTF8String

@@ -35,7 +35,7 @@ private object Triggers {
}

def convert(interval: String): Long = {
val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
val cal = SparkIntervalUtils.stringToInterval(UTF8String.fromString(interval))
if (cal.months != 0) {
throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
}
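For context, this converter backs the processing-time trigger API; a typical caller-side use (standard Spark API, unchanged by this diff) looks like:

import org.apache.spark.sql.streaming.Trigger

// The interval string is parsed via stringToInterval as shown above;
// a month or year interval would hit the IllegalArgumentException.
val trigger = Trigger.ProcessingTime("10 seconds")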
@@ -26,7 +26,7 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
import org.apache.spark.util.Utils
import org.apache.spark.util.SparkSerDerseUtils

/**
* A user-defined function. To create one, use the `udf` functions in `functions`.
@@ -103,7 +103,8 @@ case class ScalarUserDefinedFunction(

// SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class.
private[this] val udf = {
val udfPacketBytes = Utils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
val udfPacketBytes =
SparkSerDerseUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
val scalaUdfBuilder = proto.ScalarScalaUDF
.newBuilder()
.setPayload(ByteString.copyFrom(udfPacketBytes))
@@ -1252,6 +1252,14 @@ object functions {
*/
def some(e: Column): Column = Column.fn("some", e)

/**
* Aggregate function: returns true if at least one value of `e` is true.
*
* @group agg_funcs
* @since 3.5.0
*/
def any(e: Column): Column = Column.fn("any", e)
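A brief usage sketch for the new `any` aggregate (the DataFrame `df` and its boolean column `flag` are assumed for illustration, not part of this diff):

import org.apache.spark.sql.functions.{any, col}

// true when at least one row has flag == true
val anyFlag = df.agg(any(col("flag")).as("any_flag"))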

/**
* Aggregate function: returns true if at least one value of `e` is true.
*
@@ -3686,6 +3694,16 @@ object functions {
*/
def length(e: Column): Column = Column.fn("length", e)

/**
* Computes the character length of a given string or number of bytes of a binary string. The
length of character strings includes the trailing spaces. The length of binary strings
* includes binary zeros.
*
* @group string_funcs
* @since 3.5.0
*/
def len(e: Column): Column = Column.fn("len", e)
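Similarly for `len`, shown here against an assumed DataFrame `df` with a string column `g`, mirroring the golden test added below:

import org.apache.spark.sql.functions.{col, len, length}

// len counts characters for strings (including trailing spaces) and bytes for binary values
df.select(len(col("g")), length(col("g")))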

/**
* Converts a string column to lower case.
*
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.AvailableNowTrigger
import org.apache.spark.sql.execution.streaming.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.OneTimeTrigger
import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
import org.apache.spark.util.Utils
import org.apache.spark.util.SparkSerDerseUtils

/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -214,7 +214,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
* @since 3.5.0
*/
def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
val serialized = Utils.serialize(ForeachWriterPacket(writer, ds.encoder))
val serialized = SparkSerDerseUtils.serialize(ForeachWriterPacket(writer, ds.encoder))
val scalaWriterBuilder = proto.ScalarScalaUDF
.newBuilder()
.setPayload(ByteString.copyFrom(serialized))
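The object being serialized is the standard ForeachWriter sink supplied by the user; a minimal caller-side sketch (the streaming Dataset `stream` is assumed):

import org.apache.spark.sql.ForeachWriter

val writer = new ForeachWriter[String] {
  def open(partitionId: Long, epochId: Long): Boolean = true
  def process(value: String): Unit = println(value)
  def close(errorOrNull: Throwable): Unit = ()
}

// foreach(...) serializes the writer together with the dataset encoder into a
// ForeachWriterPacket payload, as the diff above shows.
stream.writeStream.foreach(writer).start()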
@@ -1181,6 +1181,10 @@ class PlanGenerationTestSuite
boolean.select(fn.some(fn.col("flag")))
}

test("function any") {
boolean.select(fn.any(fn.col("flag")))
}

test("function bool_or") {
boolean.select(fn.bool_or(fn.col("flag")))
}
@@ -1629,6 +1633,10 @@ class PlanGenerationTestSuite
fn.length(fn.col("g"))
}

functionTest("len") {
fn.len(fn.col("g"))
}

functionTest("lower") {
fn.lower(fn.col("g"))
}
@@ -0,0 +1,2 @@
Aggregate [max(flag#0) AS any(flag)#0]
+- LocalRelation <empty>, [id#0L, flag#0]
@@ -0,0 +1,2 @@
Project [len(g#0) AS len(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,25 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,flag:boolean\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "any",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "flag"
}
}]
}
}]
}
}
Binary file not shown.
@@ -0,0 +1,25 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "len",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "g"
}
}]
}
}]
}
}
Binary file not shown.
@@ -29,7 +29,7 @@ import scala.reflect.ClassTag
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

import org.apache.spark.{JobArtifactSet, SparkContext, SparkEnv}
import org.apache.spark.{JobArtifactState, SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
@@ -56,15 +56,15 @@ import org.apache.spark.util.Utils
class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging {
import SparkConnectArtifactManager._

private val sessionUUID = sessionHolder.session.sessionUUID
// The base directory/URI where all artifacts are stored for this `sessionUUID`.
val (artifactPath, artifactURI): (Path, String) =
getArtifactDirectoryAndUriForSession(sessionHolder)
// The base directory/URI where all class file artifacts are stored for this `sessionUUID`.
val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder)
val state: JobArtifactState =
JobArtifactState(sessionHolder.session.sessionUUID, Option(classURI))

private val jarsList = new CopyOnWriteArrayList[Path]
private val jarsURI = new CopyOnWriteArrayList[String]
private val pythonIncludeList = new CopyOnWriteArrayList[String]

/**
@@ -132,10 +132,16 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
}
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
sessionHolder.session.sessionState.resourceLoader
.addJar(target.toString, state.uuid)
jarsList.add(target)
jarsURI.add(artifactURI + "/" + remoteRelativePath.toString)
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
sessionHolder.session.sparkContext.addFile(target.toString)
sessionHolder.session.sparkContext.addFile(
target.toString,
recursive = false,
addedOnSubmit = false,
isArchive = false,
sessionUUID = state.uuid)
val stringRemotePath = remoteRelativePath.toString
if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
".egg") || stringRemotePath.endsWith(".jar")) {
@@ -144,35 +150,28 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
} else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
val canonicalUri =
fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
sessionHolder.session.sparkContext.addFile(
canonicalUri.toString,
recursive = false,
addedOnSubmit = false,
isArchive = true,
sessionUUID = state.uuid)
} else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
sessionHolder.session.sparkContext.addFile(target.toString)
sessionHolder.session.sparkContext.addFile(
target.toString,
recursive = false,
addedOnSubmit = false,
isArchive = false,
sessionUUID = state.uuid)
}
}
}

/**
* Returns a [[JobArtifactSet]] pointing towards the session-specific jars and class files.
*/
def jobArtifactSet: JobArtifactSet = {
val builder = Map.newBuilder[String, Long]
jarsURI.forEach { jar =>
builder += jar -> 0
}

new JobArtifactSet(
uuid = Option(sessionUUID),
replClassDirUri = Option(classURI),
jars = builder.result(),
files = Map.empty,
archives = Map.empty)
}

/**
* Returns a [[ClassLoader]] for session-specific jar/class file resources.
*/
def classloader: ClassLoader = {
val urls = jarsList.asScala.map(_.toUri.toURL) :+ classDir.toUri.toURL
val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
}

@@ -183,6 +182,12 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
logDebug(
s"Cleaning up resources for session with userId: ${sessionHolder.userId} and " +
s"sessionId: ${sessionHolder.sessionId}")

// Clean up added files
sessionHolder.session.sparkContext.addedFiles.remove(state.uuid)
sessionHolder.session.sparkContext.addedArchives.remove(state.uuid)
sessionHolder.session.sparkContext.addedJars.remove(state.uuid)

// Clean up cached relations
val blockManager = sessionHolder.session.sparkContext.env.blockManager
blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId)
@@ -88,11 +88,6 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
artifactManager.addArtifact(remoteRelativePath, serverLocalStagingPath, fragment)
}

/**
* A [[JobArtifactSet]] for this SparkConnect session.
*/
def connectJobArtifactSet: JobArtifactSet = artifactManager.jobArtifactSet

/**
* A [[ClassLoader]] for jar/class file resources specific to this SparkConnect session.
*/
@@ -114,8 +109,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
def withContextClassLoader[T](f: => T): T = {
// Needed for deserializing and evaluating the UDF on the driver
Utils.withContextClassLoader(classloader) {
// Needed for propagating the dependencies to the executors.
JobArtifactSet.withActive(connectJobArtifactSet) {
JobArtifactSet.withActiveJobArtifactState(artifactManager.state) {
f
}
}
@@ -51,20 +51,6 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
super.afterEach()
}

test("Jar artifacts are added to spark session") {
val copyDir = Utils.createTempDir().toPath
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
val stagingPath = copyDir.resolve("smallJar.jar")
val remotePath = Paths.get("jars/smallJar.jar")
artifactManager.addArtifact(remotePath, stagingPath, None)

val expectedPath = SparkConnectArtifactManager.artifactRootPath
.resolve(s"$sessionUUID/jars/smallJar.jar")
assert(expectedPath.toFile.exists())
val jars = artifactManager.jobArtifactSet.jars
assert(jars.exists(_._1.contains(remotePath.toString)))
}

test("Class artifacts are added to the correct directory.") {
val copyDir = Utils.createTempDir().toPath
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
@@ -130,6 +130,8 @@ <h4 class="title-table">Executors</h4>
<th>Thread Dump</th>
<th>Heap Histogram</th>
<th>Exec Loss Reason</th>
<th>Add Time</th>
<th>Remove Time</th>
</tr>
</thead>
<tbody>
@@ -16,7 +16,7 @@
*/

/* global $, Mustache, createRESTEndPointForExecutorsPage, createRESTEndPointForMiscellaneousProcess, */
/* global createTemplateURI, formatBytes, formatDuration, formatLogsCells, getStandAloneAppId, */
/* global createTemplateURI, formatBytes, formatDate, formatDuration, formatLogsCells, getStandAloneAppId, */
/* global jQuery, setDataTableDefaults */

var threadDumpEnabled = false;
@@ -568,6 +568,14 @@ $(document).ready(function () {
{
data: 'removeReason',
render: formatLossReason
},
{
data: 'addTime',
render: formatDate
},
{
data: 'removeTime',
render: formatDate
}
],
"order": [[0, "asc"]],
@@ -175,7 +175,7 @@ function setDataTableDefaults() {
}

function formatDate(date) {
if (date <= 0) return "-";
if (!date || date <= 0) return "-";
else {
var dt = new Date(date.replace("GMT", "Z"));
return formatDateString(dt);