[SPARK-44398][CONNECT] Scala foreachBatch API
### What changes were proposed in this pull request?
This implements Scala `foreachBatch()`. The implementation is basic and needs further enhancements. The server-side handling will be shared by the Python implementation as well.

One notable hack in this PR is that it runs the user's `foreachBatch()` function with a regular (legacy) DataFrame, rather than setting up a remote Spark Connect session and a Connect DataFrame.

### Why are the changes needed?
Adds `foreachBatch()` support to the Scala Spark Connect client.

### Does this PR introduce _any_ user-facing change?
Yes. Adds the `foreachBatch()` API.
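
A minimal usage sketch from the Connect Scala client (the rate source options, class name, and view name below are illustrative; the function passed to `foreachBatch()` must be serializable because it is shipped to the server):

```scala
import org.apache.spark.sql.DataFrame

// Hypothetical serializable function, mirroring the style used in the new unit test.
class CountsPerBatch extends ((DataFrame, Long) => Unit) with Serializable {
  override def apply(df: DataFrame, batchId: Long): Unit = {
    // Runs on the server for each micro-batch (currently with a legacy DataFrame, per the
    // hack noted above); records the batch id and row count in a global temp view.
    df.sparkSession
      .createDataFrame(Seq((batchId, df.count())))
      .createOrReplaceGlobalTempView("counts_per_batch")
  }
}

val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "10")
  .load()
  .writeStream
  .foreachBatch(new CountsPerBatch)
  .start()
```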

### How was this patch tested?
- A simple unit test.

Closes apache#41969 from rangadi/feb-scala.

Authored-by: Raghu Angadi <raghu.angadi@databricks.com>
Signed-off-by: Xinrong Meng <xinrong@apache.org>
rangadi authored and xinrong-meng committed Jul 13, 2023
1 parent 7e0679b commit 4771853
Showing 8 changed files with 251 additions and 72 deletions.
@@ -30,12 +30,15 @@ import org.apache.spark.connect.proto.Command
import org.apache.spark.connect.proto.WriteStreamOperationStart
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, ForeachWriter}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.ForeachWriterPacket
import org.apache.spark.sql.execution.streaming.AvailableNowTrigger
import org.apache.spark.sql.execution.streaming.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.OneTimeTrigger
import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
import org.apache.spark.sql.types.NullType
import org.apache.spark.util.SparkSerDeUtils
import org.apache.spark.util.Utils

/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -218,7 +221,30 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
val scalaWriterBuilder = proto.ScalarScalaUDF
.newBuilder()
.setPayload(ByteString.copyFrom(serialized))
sinkBuilder.getForeachWriterBuilder.setScalaWriter(scalaWriterBuilder)
sinkBuilder.getForeachWriterBuilder.setScalaFunction(scalaWriterBuilder)
this
}

/**
* :: Experimental ::
*
* (Scala-specific) Sets the output of the streaming query to be processed using the provided
* function. This is supported only in the micro-batch execution modes (that is, when the
* trigger is not continuous). In every micro-batch, the provided function will be called
* with (i) the output rows as a Dataset and (ii) the batch identifier. The
* batchId can be used to deduplicate and transactionally write the output (that is, the
* provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the
* same for the same batchId (assuming all operations are deterministic in the query).
*
* @since 3.5.0
*/
@Evolving
def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
val serializedFn = Utils.serialize(function)
sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder
.setPayload(ByteString.copyFrom(serializedFn))
.setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused.
.setNullable(true) // Unused.
this
}
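
For illustration only (not part of this change), a sketch of how a caller might use `batchId` to make the sink idempotent; the table name and `batch_id` column are hypothetical:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

// Sketch: tag every output row with the batch id so that a replayed micro-batch can be
// detected and deduplicated downstream; "events_by_batch" is a hypothetical table name.
val idempotentWrite: (DataFrame, Long) => Unit = (batch, batchId) =>
  batch.withColumn("batch_id", lit(batchId))
    .write
    .mode("append")
    .saveAsTable("events_by_batch")

// writeStream.foreachBatch(idempotentWrite).start()
```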

@@ -26,13 +26,14 @@ import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.timeout
import org.scalatest.time.SpanSugar._

import org.apache.spark.sql.{ForeachWriter, Row, SparkSession, SQLHelper}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession, SQLHelper}
import org.apache.spark.sql.connect.client.util.RemoteSparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.window
import org.apache.spark.util.Utils

class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
class StreamingQuerySuite extends RemoteSparkSession with SQLHelper with Logging {

test("Streaming API with windowed aggregate query") {
// This verifies standard streaming API by starting a streaming query with windowed count.
@@ -114,7 +115,7 @@ class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
withSQLConf(
"spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers.
) {
spark.sql("DROP TABLE IF EXISTS my_table")
spark.sql("DROP TABLE IF EXISTS my_table").collect()

withTempPath { ckpt =>
val q1 = spark.readStream
@@ -266,6 +267,42 @@ class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
q.stop()
assert(!q1.isActive)
}

test("foreachBatch") {
// Starts a streaming query with a foreachBatch function, which writes batchId and row count
// to a temp view. The test verifies that the view is populated with data.

val viewName = "test_view"
val tableName = s"global_temp.$viewName"

withTable(tableName) {
val q = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.option("numPartitions", "1")
.load()
.writeStream
.foreachBatch(new ForeachBatchFn(viewName))
.start()

eventually(timeout(30.seconds)) { // Wait for first progress.
assert(q.lastProgress != null)
assert(q.lastProgress.numInputRows > 0)
}

eventually(timeout(30.seconds)) {
// There should be row(s) in the temporary view created by foreachBatch.
val rows = spark
.sql(s"select * from $tableName")
.collect()
.toSeq
assert(rows.size > 0)
log.info(s"Rows in $tableName: $rows")
}

q.stop()
}
}
}

class TestForeachWriter[T] extends ForeachWriter[T] {
@@ -292,3 +329,12 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
case class TestClass(value: Int) {
override def toString: String = value.toString
}

class ForeachBatchFn(val viewName: String) extends ((DataFrame, Long) => Unit) with Serializable {
override def apply(df: DataFrame, batchId: Long): Unit = {
val count = df.count()
df.sparkSession
.createDataFrame(Seq((batchId, count)))
.createOrReplaceGlobalTempView(viewName)
}
}
@@ -214,13 +214,14 @@ message WriteStreamOperationStart {
string table_name = 12;
}

StreamingForeachWriter foreach_writer = 13;
StreamingForeachFunction foreach_writer = 13;
StreamingForeachFunction foreach_batch = 14;
}

message StreamingForeachWriter {
oneof writer {
PythonUDF python_writer = 1;
ScalarScalaUDF scala_writer = 2;
message StreamingForeachFunction {
oneof function {
PythonUDF python_function = 1;
ScalarScalaUDF scala_function = 2;
}
}

@@ -33,6 +33,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingForeachFunction
import org.apache.spark.connect.proto.StreamingQueryManagerCommand
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -2661,13 +2662,13 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
}

if (writeOp.hasForeachWriter) {
if (writeOp.getForeachWriter.hasPythonWriter) {
val foreach = writeOp.getForeachWriter.getPythonWriter
if (writeOp.getForeachWriter.hasPythonFunction) {
val foreach = writeOp.getForeachWriter.getPythonFunction
val pythonFcn = transformPythonFunction(foreach)
writer.foreachImplementation(
new PythonForeachWriter(pythonFcn, dataset.schema).asInstanceOf[ForeachWriter[Any]])
} else {
val foreachWriterPkt = unpackForeachWriter(writeOp.getForeachWriter.getScalaWriter)
val foreachWriterPkt = unpackForeachWriter(writeOp.getForeachWriter.getScalaFunction)
val clientWriter = foreachWriterPkt.foreachWriter
val encoder: Option[ExpressionEncoder[Any]] = Try(
ExpressionEncoder(
@@ -2676,6 +2677,24 @@
}
}

if (writeOp.hasForeachBatch) {
val foreachBatchFn = writeOp.getForeachBatch.getFunctionCase match {
case StreamingForeachFunction.FunctionCase.PYTHON_FUNCTION =>
throw InvalidPlanInput("Python ForeachBatch is not supported yet. WIP.")

case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray,
Utils.getContextOrSparkClassLoader)
StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)

case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
throw InvalidPlanInput("Unexpected")
}

writer.foreachBatch(foreachBatchFn)
}

val query = writeOp.getPath match {
case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
case "" => writer.start()
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.planner

import java.util.UUID

import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.service.SessionHolder

/**
* A helper object for handling ForeachBatch-related functionality in Spark Connect servers.
*/
object StreamingForeachBatchHelper extends Logging {

type ForeachBatchFnType = (DataFrame, Long) => Unit

/**
* Returns a new foreachBatch function that wraps `fn`. It sets up the DataFrame cache so that
* the user function can access it. The cache is cleared once foreachBatch returns.
*/
def dataFrameCachingWrapper(
fn: ForeachBatchFnType,
sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, batchId: Long) =>
{
val dfId = UUID.randomUUID().toString
log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to the log.

// TODO: Sanity check there is no other active DataFrame for this query. Need to include
// query id available in the cache for this check.

sessionHolder.cacheDataFrameById(dfId, df)
try {
fn(df, batchId)
} finally {
log.info(s"Removing DataFrame with id $dfId from the cache")
sessionHolder.removeCachedDataFrame(dfId)
}
}
}

/**
* Handles setting up the Scala remote session and the rest of the Spark Connect environment,
* and then runs the provided foreachBatch function `fn`.
*
* HACK ALERT: This version does not actually set up Spark Connect. It passes the DataFrame
* directly, so the user code actually runs with a legacy DataFrame.
*/
def scalaForeachBatchWrapper(
fn: ForeachBatchFnType,
sessionHolder: SessionHolder): ForeachBatchFnType = {
// TODO: Set up Spark Connect session. Do we actually need this for the first version?
dataFrameCachingWrapper(fn, sessionHolder)
}
}
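
A small sketch of what the wrapping above amounts to at runtime (assumes a `sessionHolder` is in scope and `userFn` stands in for the deserialized user function):

```scala
// The user's function runs while the micro-batch DataFrame is cached in the session under a
// fresh UUID; the cache entry is removed as soon as the function returns or throws.
val userFn: StreamingForeachBatchHelper.ForeachBatchFnType =
  (df, batchId) => println(s"Batch $batchId has ${df.count()} rows")

val wrapped = StreamingForeachBatchHelper.scalaForeachBatchWrapper(userFn, sessionHolder)
// The streaming engine then calls wrapped(df, batchId) for every micro-batch.
```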