Merge pull request #1633 from apache/master
GulajavaMinistudio authored Mar 21, 2024
2 parents 6182ec8 + 3a1609a commit 68731cf
Showing 133 changed files with 5,430 additions and 2,053 deletions.
2 changes: 2 additions & 0 deletions .github/labeler.yml
@@ -101,6 +101,8 @@ SQL:
]
- any-glob-to-any-file: [
'common/unsafe/**/*',
'common/sketch/**/*',
'common/variant/**/*',
'bin/spark-sql*',
'bin/beeline*',
'sbin/*thriftserver*.sh',
6 changes: 3 additions & 3 deletions .github/workflows/build_sparkr_window.yml
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
name: "Build / SparkR-only (master, 4.3.2, windows-2019)"
name: "Build / SparkR-only (master, 4.3.3, windows-2019)"

on:
schedule:
@@ -50,10 +50,10 @@ jobs:
with:
distribution: zulu
java-version: 17
- name: Install R 4.3.2
- name: Install R 4.3.3
uses: r-lib/actions/setup-r@v2
with:
r-version: 4.3.2
r-version: 4.3.3
- name: Install R dependencies
run: |
Rscript -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival', 'arrow', 'xml2'), repos='https://cloud.r-project.org/')"
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -2083,6 +2083,11 @@
"Operation not found."
]
},
"SESSION_CHANGED" : {
"message" : [
"The existing Spark server driver instance has restarted. Please reconnect."
]
},
"SESSION_CLOSED" : {
"message" : [
"Session was closed."
@@ -196,7 +196,7 @@ private[spark] object Logging {
val initLock = new Object()
try {
// We use reflection here to handle the case where users remove the
// slf4j-to-jul bridge order to route their logs to JUL.
// jul-to-slf4j bridge order to route their logs to JUL.
val bridgeClass = SparkClassUtils.classForName("org.slf4j.bridge.SLF4JBridgeHandler")
bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null)
val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean]
@@ -197,19 +197,21 @@ private[sql] object AvroUtils extends Logging {

def hasNextRow: Boolean = {
while (!completed && currentRow.isEmpty) {
if (fileReader.pastSync(stopPosition)) {
// In the case of empty blocks in an Avro file, `blockRemaining` could still read as 0 so
// `fileReader.hasNext()` returns false but advances the cursor to the next block, so we
// need to call `fileReader.hasNext()` again to correctly report if the next record
// exists.
val moreData =
(fileReader.hasNext || fileReader.hasNext) && !fileReader.pastSync(stopPosition)
if (!moreData) {
fileReader.close()
completed = true
currentRow = None
} else if (fileReader.hasNext()) {
} else {
val record = fileReader.next()
// the row must be deserialized in hasNextRow, because AvroDeserializer#deserialize
// potentially filters rows
currentRow = deserializer.deserialize(record).asInstanceOf[Option[InternalRow]]
} else {
// In this case, `fileReader.hasNext()` returns false but we are not past sync point yet.
// This means empty blocks, we need to continue reading the file in case there are non
// empty blocks or we are past sync point.
}
}
currentRow.isDefined
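
Side note on the fix above: as the new comment explains, when an Avro file contains empty blocks a single `fileReader.hasNext()` call can return false while still advancing the cursor to the next block, which is why the patched condition calls `hasNext` twice before giving up. The following standalone sketch mirrors that pattern against Avro's DataFileReader directly; the file path is hypothetical and this is an illustration, not the Spark code itself.

import java.io.File

import org.apache.avro.file.DataFileReader
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}

object EmptyBlockReadSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical path; any Avro file, including one made up of empty blocks, works here.
    val reader = new DataFileReader[GenericRecord](
      new File("/tmp/empty_file.avro"), new GenericDatumReader[GenericRecord]())
    val stopPosition = Long.MaxValue // this sketch simply reads to the end of the file
    try {
      // The double hasNext call lets the reader step over an empty block: the first
      // call may return false while moving the cursor to the next block.
      while ((reader.hasNext || reader.hasNext) && !reader.pastSync(stopPosition)) {
        println(reader.next())
      }
    } finally {
      reader.close()
    }
  }
}
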
Binary file added connector/avro/src/test/resources/empty_file.avro
@@ -2744,6 +2744,16 @@ abstract class AvroSuite
}
}
}

test("SPARK-46990: read an empty file where pastSync returns false at EOF") {
for (maxPartitionBytes <- Seq(100, 100000, 100000000)) {
withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> s"$maxPartitionBytes") {
val file = getResourceAvroFilePath("empty_file.avro")
val df = spark.read.format("avro").load(file)
assert(df.count() == 0)
}
}
}
}

class AvroV1Suite extends AvroSuite {
@@ -66,6 +66,12 @@ message AnalyzePlanRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 17;

// (Required) User context
UserContext user_context = 2;

@@ -281,6 +287,12 @@ message ExecutePlanRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 8;

// (Required) User context
//
// user_context.user_id and session+id both identify a unique remote spark session on the
@@ -443,6 +455,12 @@ message ConfigRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 8;

// (Required) User context
UserContext user_context = 2;

@@ -536,6 +554,12 @@ message AddArtifactsRequest {
// User context
UserContext user_context = 2;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// Provides optional information about the client sending the request. This field
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
@@ -630,6 +654,12 @@ message ArtifactStatusesRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 5;

// User context
UserContext user_context = 2;

@@ -673,6 +703,12 @@ message InterruptRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// (Required) User context
UserContext user_context = 2;

@@ -738,6 +774,12 @@ message ReattachExecuteRequest {
// This must be an id of existing session.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 6;

// (Required) User context
//
// user_context.user_id and session+id both identify a unique remote spark session on the
@@ -772,6 +814,12 @@ message ReleaseExecuteRequest {
// This must be an id of existing session.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// (Required) User context
//
// user_context.user_id and session+id both identify a unique remote spark session on the
@@ -856,6 +904,12 @@ message FetchErrorDetailsRequest {
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 5;

// User context
UserContext user_context = 2;

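
Each request message above gains the same optional `client_observed_server_side_session_id` field: the client echoes back the server-side session id it last observed, and the server can refuse the request when its current session id no longer matches, surfacing the new SESSION_CHANGED error class added in error-classes.json. A minimal sketch of such a check follows; the helper name and exception type are illustrative and not taken from the actual Spark Connect service code.

object SessionValidationSketch {
  // Rejects a request whose client-observed server-side session id no longer matches
  // the session that is actually serving the request (e.g. after a driver restart).
  def validateObservedSessionId(
      observedSessionId: Option[String], // client_observed_server_side_session_id, if set
      currentServerSessionId: String): Unit = {
    observedSessionId.foreach { observed =>
      if (observed != currentServerSessionId) {
        // Conceptually maps onto the SESSION_CHANGED error class added above:
        // the driver has restarted, so the client has to reconnect.
        throw new IllegalStateException(
          s"Server-side session changed: observed $observed, current $currentServerSessionId")
      }
    }
  }
}
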
@@ -21,15 +21,27 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.internal.Logging

// This is common logic to be shared between different stub instances to validate responses as
// seen by the client.
// This is common logic to be shared between different stub instances to keep the server-side
// session id and to validate responses as seen by the client.
class ResponseValidator extends Logging {

// Server side session ID, used to detect if the server side session changed. This is set upon
// receiving the first response from the server. This value is used only for executions that
// do not use server-side streaming.
private var serverSideSessionId: Option[String] = None

// Returns the server side session ID, used to send it back to the server in the follow-up
// requests so the server can validate its session id against the previous requests.
def getServerSideSessionId: Option[String] = serverSideSessionId

/**
* Hijacks the stored server side session ID with the given suffix. Used for testing to make
* sure that the server is validating the session ID.
*/
private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = {
serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix)
}

def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
val response = fn
val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
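
The body of verifyResponse is collapsed in this view. Based on the comments above, it reads the `server_side_session_id` field from each response, remembers it on first sight, and flags any later mismatch. A hedged sketch of that capture-then-compare logic (not the actual, truncated implementation):

import com.google.protobuf.GeneratedMessageV3

class ResponseValidatorSketch {
  private var serverSideSessionId: Option[String] = None

  def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
    val response = fn
    val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
    // Older servers may not populate the field at all, so only validate when present.
    if (field != null) {
      val value = response.getField(field).asInstanceOf[String]
      if (value != null && value.nonEmpty) {
        serverSideSessionId match {
          case Some(previous) if previous != value =>
            // The driver behind this session changed between two responses.
            throw new IllegalStateException(
              s"Server side session ID changed from $previous to $value")
          case None =>
            // First response seen by this client: remember the server's session id.
            serverSideSessionId = Some(value)
          case _ => // Same session as before, nothing to do.
        }
      }
    }
    response
  }
}
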
@@ -63,6 +63,14 @@ private[sql] class SparkConnectClient(
// a new client will create a new session ID.
private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString)

/**
* Hijacks the stored server side session ID with the given suffix. Used for testing to make
* sure that the server is validating the session ID.
*/
private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = {
stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
}

private[sql] val artifactManager: ArtifactManager = {
new ArtifactManager(configuration, sessionId, bstub, stub)
}
@@ -73,6 +81,14 @@
private[sql] def uploadAllClassFileArtifacts(): Unit =
artifactManager.uploadAllClassFileArtifacts()

/**
* Returns the server-side session id obtained from the first request, if there was a request
* already.
*/
private def serverSideSessionId: Option[String] = {
stubState.responseValidator.getServerSideSessionId
}

/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
* @return
@@ -99,11 +115,11 @@
.setSessionId(sessionId)
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)
.build()
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
if (configuration.useReattachableExecute) {
bstub.executePlanReattachable(request)
bstub.executePlanReattachable(request.build())
} else {
bstub.executePlan(request)
bstub.executePlan(request.build())
}
}

@@ -119,8 +135,8 @@
.setSessionId(sessionId)
.setClientType(userAgent)
.setUserContext(userContext)
.build()
bstub.config(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.config(request.build())
}

/**
@@ -207,8 +223,8 @@
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
analyze(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
analyze(request.build())
}

private[sql] def interruptAll(): proto.InterruptResponse = {
@@ -218,8 +234,8 @@
.setSessionId(sessionId)
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def interruptTag(tag: String): proto.InterruptResponse = {
@@ -230,8 +246,8 @@
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG)
.setOperationTag(tag)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def interruptOperation(id: String): proto.InterruptResponse = {
@@ -242,8 +258,8 @@
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID)
.setOperationId(id)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def releaseSession(): proto.ReleaseSessionResponse = {
@@ -252,8 +268,7 @@
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
bstub.releaseSession(request)
bstub.releaseSession(request.build())
}

private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
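
The hijackServerSideSessionIdForTesting hook above only mutates the id cached on the client; its point is to let a test send a deliberately stale `client_observed_server_side_session_id` and confirm the server rejects it. A hypothetical usage sketch (the test shape and the expectation of a plain exception are assumptions, not part of this commit):

import org.apache.spark.sql.connect.client.SparkConnectClient

// Hypothetical test helper; assumes it lives under org.apache.spark.sql (the hijack
// hook and interruptAll are private[sql]) and that a Connect server is running.
object StaleSessionIdCheck {
  def checkServerValidatesObservedSessionId(client: SparkConnectClient): Unit = {
    // Issue one request so the client records the real server-side session id.
    client.interruptAll()

    // Corrupt the cached id; subsequent requests now carry a mismatching
    // client_observed_server_side_session_id.
    client.hijackServerSideSessionIdForTesting("-stale-suffix")

    val rejected =
      try { client.interruptAll(); false }
      catch { case _: Exception => true }
    assert(rejected, "expected the server to reject the stale observed session id")
  }
}
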
@@ -29,10 +29,9 @@ import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.{FieldVector, VarCharVector, VectorSchemaRoot}
import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
import org.apache.arrow.vector.ipc.ArrowReader
import org.apache.arrow.vector.util.Text

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
@@ -468,16 +467,6 @@ object ArrowDeserializers {

private def isTuple(cls: Class[_]): Boolean = cls.getName.startsWith("scala.Tuple")

private def getString(v: VarCharVector, i: Int): String = {
// This is currently a bit heavy on allocations:
// - byte array created in VarCharVector.get
// - CharBuffer created CharSetEncoder
// - char array in String
// By using direct buffers and reusing the char buffer
// we could get rid of the first two allocations.
Text.decode(v.get(i))
}

private def loadListIntoBuilder(
v: ListVector,
i: Int,
@@ -53,9 +53,14 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr

override def onNext(req: AddArtifactsRequest): Unit = try {
if (this.holder == null) {
val previousSessionId = req.hasClientObservedServerSideSessionId match {
case true => Some(req.getClientObservedServerSideSessionId)
case false => None
}
this.holder = SparkConnectService.getOrCreateIsolatedSession(
req.getUserContext.getUserId,
req.getSessionId)
req.getSessionId,
previousSessionId)
}

if (req.hasBeginChunk) {
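
The match on hasClientObservedServerSideSessionId above converts protobuf's has/get pair into an Option before passing it to getOrCreateIsolatedSession. For reference, a hypothetical equivalent using Option.when (not part of this change):

import org.apache.spark.connect.proto.AddArtifactsRequest

object ObservedSessionIdSketch {
  // Hypothetical helper: the same has/get-to-Option conversion, written with Option.when.
  def observedSessionId(req: AddArtifactsRequest): Option[String] =
    Option.when(req.hasClientObservedServerSideSessionId)(
      req.getClientObservedServerSideSessionId)
}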