[SPARK-44398][CONNECT] Scala foreachBatch API #41969

File: DataStreamWriter.scala (Spark Connect Scala client)
@@ -30,12 +30,15 @@ import org.apache.spark.connect.proto.Command
import org.apache.spark.connect.proto.WriteStreamOperationStart
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, ForeachWriter}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.ForeachWriterPacket
import org.apache.spark.sql.execution.streaming.AvailableNowTrigger
import org.apache.spark.sql.execution.streaming.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.OneTimeTrigger
import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
import org.apache.spark.sql.types.NullType
import org.apache.spark.util.SparkSerDeUtils
import org.apache.spark.util.Utils

/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -218,7 +221,30 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
val scalaWriterBuilder = proto.ScalarScalaUDF
.newBuilder()
.setPayload(ByteString.copyFrom(serialized))
sinkBuilder.getForeachWriterBuilder.setScalaWriter(scalaWriterBuilder)
sinkBuilder.getForeachWriterBuilder.setScalaFunction(scalaWriterBuilder)
this
}

/**
* :: Experimental ::
*
* (Scala-specific) Sets the output of the streaming query to be processed using the provided
* function. This is supported only in the micro-batch execution modes (that is, when the
* trigger is not continuous). The provided function will be called in every micro-batch with
* (i) the output rows as a Dataset and (ii) the batch identifier. The
* batchId can be used to deduplicate and transactionally write the output (that is, the
* provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the
* same for the same batchId (assuming all operations are deterministic in the query).
*
* @since 3.5.0
*/
@Evolving
def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
val serializedFn = Utils.serialize(function)
sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder
.setPayload(ByteString.copyFrom(serializedFn))
.setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused.
.setNullable(true) // Unused.
this
}

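For reference, a minimal client-side usage sketch of the new API. The connect URL, the rate-source options, and the sink logic are illustrative only and not part of this change; it assumes the Spark Connect Scala client is on the classpath.

import org.apache.spark.sql.{DataFrame, SparkSession}

// Connect to a (hypothetical) Spark Connect server.
val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()

val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "5")
  .load()
  .writeStream
  .foreachBatch { (batch: DataFrame, batchId: Long) =>
    // The function is serialized on the client and invoked on the server once per
    // micro-batch; batchId can be used for idempotent or transactional writes.
    println(s"Batch $batchId contained ${batch.count()} rows")
  }
  .start()

Any serializable function value of type (DataFrame, Long) => Unit works equally well, as the ForeachBatchFn helper in the test suite below shows.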
File: StreamingQuerySuite.scala (Spark Connect Scala client tests)
@@ -26,13 +26,14 @@ import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.timeout
import org.scalatest.time.SpanSugar._

import org.apache.spark.sql.{ForeachWriter, Row, SparkSession, SQLHelper}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession, SQLHelper}
import org.apache.spark.sql.connect.client.util.RemoteSparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.window
import org.apache.spark.util.Utils

class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
class StreamingQuerySuite extends RemoteSparkSession with SQLHelper with Logging {

test("Streaming API with windowed aggregate query") {
// This verifies standard streaming API by starting a streaming query with windowed count.
@@ -114,7 +115,7 @@ class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
withSQLConf(
"spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers.
) {
spark.sql("DROP TABLE IF EXISTS my_table")
spark.sql("DROP TABLE IF EXISTS my_table").collect()

withTempPath { ckpt =>
val q1 = spark.readStream
@@ -266,6 +267,42 @@ class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
q.stop()
assert(!q1.isActive)
}

test("foreachBatch") {
// Starts a streaming query with a foreachBatch function, which writes batchId and row count
// to a temp view. The test verifies that the view is populated with data.

val viewName = "test_view"
val tableName = s"global_temp.$viewName"

withTable(tableName) {
val q = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.option("numPartitions", "1")
.load()
.writeStream
.foreachBatch(new ForeachBatchFn(viewName))
.start()

eventually(timeout(30.seconds)) { // Wait for first progress.
assert(q.lastProgress != null)
assert(q.lastProgress.numInputRows > 0)
}

eventually(timeout(30.seconds)) {
// There should be row(s) in the temporary view created by foreachBatch.
val rows = spark
.sql(s"select * from $tableName")
.collect()
.toSeq
assert(rows.size > 0)
log.info(s"Rows in $tableName: $rows")
}

q.stop()
}
}
}

class TestForeachWriter[T] extends ForeachWriter[T] {
@@ -292,3 +329,12 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
case class TestClass(value: Int) {
override def toString: String = value.toString
}

class ForeachBatchFn(val viewName: String) extends ((DataFrame, Long) => Unit) with Serializable {
override def apply(df: DataFrame, batchId: Long): Unit = {
val count = df.count()
df.sparkSession
.createDataFrame(Seq((batchId, count)))
.createOrReplaceGlobalTempView(viewName)
}
}
File: commands.proto (Spark Connect protocol)
@@ -214,13 +214,14 @@ message WriteStreamOperationStart {
string table_name = 12;
}

StreamingForeachWriter foreach_writer = 13;
StreamingForeachFunction foreach_writer = 13;
StreamingForeachFunction foreach_batch = 14;
}

message StreamingForeachWriter {
oneof writer {
PythonUDF python_writer = 1;
ScalarScalaUDF scala_writer = 2;
message StreamingForeachFunction {
oneof function {
PythonUDF python_function = 1;
ScalarScalaUDF scala_function = 2;
}
}

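As an illustration of the renamed protocol message, the same payload can be constructed directly with the generated protobuf builders; this mirrors what the client's DataStreamWriter does internally, and the payload bytes here are only a placeholder.

import com.google.protobuf.ByteString
import org.apache.spark.connect.proto

// Hand-built StreamingForeachFunction carrying a (placeholder) serialized Scala function.
val foreachBatchFn = proto.StreamingForeachFunction
  .newBuilder()
  .setScalaFunction(
    proto.ScalarScalaUDF
      .newBuilder()
      .setPayload(ByteString.copyFromUtf8("<serialized (DataFrame, Long) => Unit>")))
  .build()

// Attached to the streaming sink via the new foreach_batch field (14).
val writeOp = proto.WriteStreamOperationStart
  .newBuilder()
  .setForeachBatch(foreachBatchFn)
  .build()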
File: SparkConnectPlanner.scala (Spark Connect server)
@@ -33,6 +33,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingForeachFunction
import org.apache.spark.connect.proto.StreamingQueryManagerCommand
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -2661,13 +2662,13 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
}

if (writeOp.hasForeachWriter) {
if (writeOp.getForeachWriter.hasPythonWriter) {
val foreach = writeOp.getForeachWriter.getPythonWriter
if (writeOp.getForeachWriter.hasPythonFunction) {
val foreach = writeOp.getForeachWriter.getPythonFunction
val pythonFcn = transformPythonFunction(foreach)
writer.foreachImplementation(
new PythonForeachWriter(pythonFcn, dataset.schema).asInstanceOf[ForeachWriter[Any]])
} else {
val foreachWriterPkt = unpackForeachWriter(writeOp.getForeachWriter.getScalaWriter)
val foreachWriterPkt = unpackForeachWriter(writeOp.getForeachWriter.getScalaFunction)
val clientWriter = foreachWriterPkt.foreachWriter
val encoder: Option[ExpressionEncoder[Any]] = Try(
ExpressionEncoder(
@@ -2676,6 +2677,24 @@
}
}

if (writeOp.hasForeachBatch) {
val foreachBatchFn = writeOp.getForeachBatch.getFunctionCase match {
case StreamingForeachFunction.FunctionCase.PYTHON_FUNCTION =>
throw InvalidPlanInput("Python ForeachBatch is not supported yet. WIP.")

case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray,
Utils.getContextOrSparkClassLoader)
StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)

case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
throw InvalidPlanInput("Unexpected")
}

writer.foreachBatch(foreachBatchFn)
}

val query = writeOp.getPath match {
case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
case "" => writer.start()
File: StreamingForeachBatchHelper.scala (new file, Spark Connect server)
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.planner

import java.util.UUID

import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.service.SessionHolder

/**
* A helper object for handling foreachBatch-related functionality in Spark Connect servers.
*/
object StreamingForeachBatchHelper extends Logging {

type ForeachBatchFnType = (DataFrame, Long) => Unit

/**
* Returns a new foreachBatch function that wraps `fn`. It registers the micro-batch DataFrame
* in the session's DataFrame cache so that the user function can access it; the cache entry is
* removed once `fn` returns.
*/
def dataFrameCachingWrapper(
fn: ForeachBatchFnType,
sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, batchId: Long) =>
{
val dfId = UUID.randomUUID().toString
log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to the log.

// TODO: Sanity check there is no other active DataFrame for this query. Need to include
// query id available in the cache for this check.

sessionHolder.cacheDataFrameById(dfId, df)
try {
fn(df, batchId)
} finally {
log.info(s"Removing DataFrame with id $dfId from the cache")
sessionHolder.removeCachedDataFrame(dfId)
}
}
}

/**
* Handles setting up the Scala remote session and the rest of the Spark Connect environment,
* and then runs the provided foreachBatch function `fn`.
*
* HACK ALERT: This version does not actually set up Spark Connect. It passes the DataFrame
* through directly, so the user code actually runs against the legacy (server-side) DataFrame.
*/
def scalaForeachBatchWrapper(
fn: ForeachBatchFnType,
sessionHolder: SessionHolder): ForeachBatchFnType = {
// TODO: Set up Spark Connect session. Do we actually need this for the first version?
dataFrameCachingWrapper(fn, sessionHolder)
}
}
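A rough sketch of how the planner applies this helper (compare the SparkConnectPlanner change above); the user function body and the sessionHolder parameter are illustrative only.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.service.SessionHolder

object ForeachBatchWiringSketch {
  // Stands in for the function deserialized from the ScalarScalaUDF payload.
  val userFn: StreamingForeachBatchHelper.ForeachBatchFnType =
    (df: DataFrame, batchId: Long) => println(s"batch $batchId: ${df.count()} rows")

  def wire(sessionHolder: SessionHolder): StreamingForeachBatchHelper.ForeachBatchFnType = {
    // Each call to the returned function caches the micro-batch DataFrame under a fresh
    // UUID while userFn runs, then removes it from the session cache.
    StreamingForeachBatchHelper.scalaForeachBatchWrapper(userFn, sessionHolder)
  }
}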