[SPARK-44341][SQL][PYTHON] Define the computing logic through PartitionEvaluator API and use it in WindowExec and WindowInPandasExec

### What changes were proposed in this pull request?
`WindowExec` and `WindowInPandasExec` are updated to use the `PartitionEvaluator` API to do execution.

### Why are the changes needed?
To define the computing logic and require the caller side to explicitly list what needs to be serialized and sent to executors.

### Does this PR introduce _any_ user-facing change?
No. It only updates the inner implementation.

### How was this patch tested?
Added new test cases.

Closes apache#41939 from beliefer/SPARK-44341.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
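For context, the `PartitionEvaluator` contract is small: a serializable factory whose constructor arguments are exactly what ships to executors, and a per-task evaluator that consumes partition iterators. Below is a minimal, self-contained sketch (assuming Spark 3.4+, where `PartitionEvaluator`, `PartitionEvaluatorFactory`, and `RDD.mapPartitionsWithEvaluator` are available; `AddOffsetEvaluatorFactory` and `PartitionEvaluatorDemo` are illustrative names, not part of this PR):

```scala
import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory}
import org.apache.spark.sql.SparkSession

// Everything the computation needs (here just `offset`) is an explicit
// constructor argument, so it is obvious what gets serialized to executors.
class AddOffsetEvaluatorFactory(offset: Int)
  extends PartitionEvaluatorFactory[Int, Int] {

  override def createEvaluator(): PartitionEvaluator[Int, Int] =
    new PartitionEvaluator[Int, Int] {
      // `inputs` carries one iterator per input partition; with a single
      // parent RDD only inputs.head is populated.
      override def eval(partitionIndex: Int, inputs: Iterator[Int]*): Iterator[Int] =
        inputs.head.map(_ + offset)
    }
}

object PartitionEvaluatorDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
    try {
      val out = spark.sparkContext
        .parallelize(1 to 4, numSlices = 2)
        .mapPartitionsWithEvaluator(new AddOffsetEvaluatorFactory(10))
        .collect()
      println(out.mkString(", ")) // prints: 11, 12, 13, 14
    } finally {
      spark.stop()
    }
  }
}
```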
1 parent 6058eba, commit 23c4e3d
Showing 5 changed files with 817 additions and 676 deletions.
...src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
369 changes: 369 additions & 0 deletions
@@ -0,0 +1,369 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import java.io.File

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{JobArtifactSet, PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

class WindowInPandasEvaluatorFactory(
    val windowExpression: Seq[NamedExpression],
    val partitionSpec: Seq[Expression],
    val orderSpec: Seq[SortOrder],
    val childOutput: Seq[Attribute],
    val spillSize: SQLMetric,
    pythonMetrics: Map[String, SQLMetric])
  extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase {

  /**
   * Helper functions and data structures for window bounds
   *
   * It contains:
   * (1) Total number of window bound indices in the python input row
   * (2) Function from frame index to its lower bound column index in the python input row
   * (3) Function from frame index to its upper bound column index in the python input row
   * (4) Seq from frame index to its window bound type
   */
  private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType])

  /**
   * Enum for window bound types. Used only inside this class.
   */
  private sealed case class WindowBoundType(value: String)

  private object UnboundedWindow extends WindowBoundType("unbounded")

  private object BoundedWindow extends WindowBoundType("bounded")

  private val windowBoundTypeConf = "pandas_window_bound_types"

  private def collectFunctions(
      udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonFuncExpression) =>
        val (chained, children) = collectFunctions(u)
        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
      case children =>
        // There should not be any other UDFs, or the children can't be evaluated directly.
        assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression])))
        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
    }
  }

  // Helper functions
  /**
   * See [[WindowBoundHelpers]] for details.
   */
  private def computeWindowBoundHelpers(
      factories: Seq[InternalRow => WindowFunctionFrame]): WindowBoundHelpers = {
    val functionFrames = factories.map(_(EmptyRow))

    val windowBoundTypes = functionFrames.map {
      case _: UnboundedWindowFunctionFrame => UnboundedWindow
      case _: UnboundedFollowingWindowFunctionFrame |
           _: SlidingWindowFunctionFrame |
           _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow
      // It should be impossible to get other types of window function frame here
      case frame => throw QueryExecutionErrors.unexpectedWindowFunctionFrameError(frame.toString)
    }

    val requiredIndices = functionFrames.map {
      case _: UnboundedWindowFunctionFrame => 0
      case _ => 2
    }

    val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail

    val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) =>
      if (num == 0) {
        // Sentinel values for unbounded window
        (-1, -1)
      } else {
        (upperBoundIndex - 2, upperBoundIndex - 1)
      }
    }

    def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1

    def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2

    (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes)
  }
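
  // Worked example: for frames [unbounded, bounded, bounded], requiredIndices is
  // Seq(0, 2, 2), so upperBoundIndices = Seq(0, 2, 4) and boundIndices =
  // Seq((-1, -1), (0, 1), (2, 3)). Frame 0 carries the (-1, -1) sentinel, frame 1
  // reads bound columns 0 and 1, frame 2 reads bound columns 2 and 3, and the
  // python input row starts with requiredIndices.sum = 4 bound columns.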

  override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = {
    new WindowInPandasPartitionEvaluator()
  }

  class WindowInPandasPartitionEvaluator extends PartitionEvaluator[InternalRow, InternalRow] {
    private val conf: SQLConf = SQLConf.get

    // Unwrap the expressions and factories from the map.
    private val expressionsWithFrameIndex =
      windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap {
        case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex))
      }

    private val expressions = expressionsWithFrameIndex.map(_._1)
    private val expressionIndexToFrameIndex =
      expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap

    private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray

    private val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) =
      computeWindowBoundHelpers(factories)
    private val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 }
    private val numFrames = factories.length

    private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
    private val spillThreshold = conf.windowExecBufferSpillThreshold
    private val sessionLocalTimeZone = conf.sessionLocalTimeZone
    private val largeVarTypes = conf.arrowUseLargeVarTypes

    // Extract window expressions and window functions
    private val windowExpressions =
      expressions.flatMap(_.collect { case e: WindowExpression => e })
    private val udfExpressions = windowExpressions.map { e =>
      e.windowFunction.asInstanceOf[AggregateExpression].aggregateFunction
        .asInstanceOf[PythonUDAF]
    }

    // We shouldn't be chaining anything here.
    // All chained python functions should only contain one function.
    private val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
    require(pyFuncs.length == expressions.length)

    private val udfWindowBoundTypes = pyFuncs.indices.map(i =>
      frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
    private val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf)
      + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(",")))

    // Filter child output attributes down to only those that are UDF inputs.
    // Also eliminate duplicate UDF inputs. This is similar to how other Python
    // UDF nodes handle UDF inputs.
    private val dataInputs = new ArrayBuffer[Expression]
    private val dataInputTypes = new ArrayBuffer[DataType]
    private val argOffsets = inputs.map { input =>
      input.map { e =>
        if (dataInputs.exists(_.semanticEquals(e))) {
          dataInputs.indexWhere(_.semanticEquals(e))
        } else {
          dataInputs += e
          dataInputTypes += e.dataType
          dataInputs.length - 1
        }
      }.toArray
    }.toArray

    // In addition to the UDF inputs, we will prepend window bounds for each UDF.
    // For bounded windows, we prepend the lower and upper bounds. For unbounded
    // windows, we do not add window bounds. (Strictly speaking, we only need the
    // lower or upper bound when the window is bounded on one side only; this can
    // be improved in the future.)

    // Set up the window bounds for each window frame. Each window frame has
    // different bounds, so each gets its own window bound columns.
    private val windowBoundsInput = factories.indices.flatMap { frameIndex =>
      if (isBounded(frameIndex)) {
        Seq(
          BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false),
          BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false)
        )
      } else {
        Seq.empty
      }
    }

    // Set the window bounds argOffsets for each UDF. For a UDF with a bounded
    // window, the argOffsets are (lowerBoundOffset, upperBoundOffset,
    // inputOffset1, inputOffset2, ...). For a UDF with an unbounded window, the
    // argOffsets are just (inputOffset1, inputOffset2, ...).
    pyFuncs.indices.foreach { exprIndex =>
      val frameIndex = expressionIndexToFrameIndex(exprIndex)
      if (isBounded(frameIndex)) {
        argOffsets(exprIndex) =
          Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
            argOffsets(exprIndex).map(_ + windowBoundsInput.length)
      } else {
        argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length)
      }
    }
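
    // Worked example: with a single bounded frame, windowBoundsInput has length 2
    // (lower and upper bound columns at offsets 0 and 1). A UDF reading data
    // inputs 0 and 1 starts with argOffsets = Array(0, 1) and ends up with
    // Array(0, 1, 2, 3): the two bound offsets first, then the data offsets
    // shifted past the window bound columns.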

    private val allInputs = windowBoundsInput ++ dataInputs
    private val allInputTypes = allInputs.map(_.dataType)
    private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)

    override def eval(
        partitionIndex: Int,
        inputs: Iterator[InternalRow]*): Iterator[InternalRow] = {
      val iter = inputs.head
      val context = TaskContext.get()

      // Get all relevant projections.
      val resultProj = createResultProjection(expressions)
      val pythonInputProj = UnsafeProjection.create(
        allInputs,
        windowBoundsInput.map(ref =>
          AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ childOutput
      )
      val pythonInputSchema = StructType(
        allInputTypes.zipWithIndex.map { case (dt, i) =>
          StructField(s"_$i", dt)
        }
      )
      val grouping = UnsafeProjection.create(partitionSpec, childOutput)

      // The queue used to buffer input rows so we can drain it to
      // combine input with output from Python.
      val queue = HybridRowQueue(context.taskMemoryManager(),
        new File(Utils.getLocalDir(SparkEnv.get.conf)), childOutput.length)
      context.addTaskCompletionListener[Unit] { _ =>
        queue.close()
      }

      val stream = iter.map { row =>
        queue.add(row.asInstanceOf[UnsafeRow])
        row
      }

      val pythonInput = new Iterator[Iterator[UnsafeRow]] {

        // Manage the stream and the grouping.
        var nextRow: UnsafeRow = null
        var nextGroup: UnsafeRow = null
        var nextRowAvailable: Boolean = false

        private[this] def fetchNextRow(): Unit = {
          nextRowAvailable = stream.hasNext
          if (nextRowAvailable) {
            nextRow = stream.next().asInstanceOf[UnsafeRow]
            nextGroup = grouping(nextRow)
          } else {
            nextRow = null
            nextGroup = null
          }
        }

        fetchNextRow()

        // Manage the current partition.
        val buffer: ExternalAppendOnlyUnsafeRowArray =
          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
        var bufferIterator: Iterator[UnsafeRow] = _

        val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType))

        val frames = factories.map(_(indexRow))

        private[this] def fetchNextPartition(): Unit = {
          // Collect all the rows in the current partition.
          // Before we start to fetch new input rows, make a copy of nextGroup.
          val currentGroup = nextGroup.copy()

          // clear last partition
          buffer.clear()

          while (nextRowAvailable && nextGroup == currentGroup) {
            buffer.add(nextRow)
            fetchNextRow()
          }

          // Setup the frames.
          var i = 0
          while (i < numFrames) {
            frames(i).prepare(buffer)
            i += 1
          }

          // Setup iteration
          rowIndex = 0
          bufferIterator = buffer.generateIterator()
        }

        // Iteration
        var rowIndex = 0

        override final def hasNext: Boolean = {
          val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
          if (!found) {
            // clear final partition
            buffer.clear()
            spillSize += buffer.spillSize
          }
          found
        }

        override final def next(): Iterator[UnsafeRow] = {
          // Load the next partition if we need to.
          if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
            fetchNextPartition()
          }

          val join = new JoinedRow

          bufferIterator.zipWithIndex.map {
            case (current, index) =>
              var frameIndex = 0
              while (frameIndex < numFrames) {
                frames(frameIndex).write(index, current)
                // If the window is unbounded we don't need to write out window bounds.
                if (isBounded(frameIndex)) {
                  indexRow.setInt(
                    lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound())
                  indexRow.setInt(
                    upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound())
                }
                frameIndex += 1
              }

              pythonInputProj(join(indexRow, current))
          }
        }
      }

      val windowFunctionResult = new ArrowPythonRunner(
        pyFuncs,
        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
        argOffsets,
        pythonInputSchema,
        sessionLocalTimeZone,
        largeVarTypes,
        pythonRunnerConf,
        pythonMetrics,
        jobArtifactUUID).compute(pythonInput, context.partitionId(), context)

      val joined = new JoinedRow

      windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
        val leftRow = queue.remove()
        val joinedRow = joined(leftRow, windowOutput)
        resultProj(joinedRow)
      }
    }
  }
}
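
For completeness, a sketch of how the caller side might wire this factory in. This is hedged: `WindowInPandasExec`'s `doExecute` is changed in a different file of this commit and is not shown here, and `child`, `longMetric`, and `pythonMetrics` are assumed from the usual `SparkPlan` context.

```scala
// Sketch, not the verbatim change: WindowInPandasExec builds the factory from
// its own fields (the explicit list of what must be serialized) and hands it
// to the child RDD, so each task creates its evaluator locally.
val evaluatorFactory = new WindowInPandasEvaluatorFactory(
  windowExpression,
  partitionSpec,
  orderSpec,
  child.output,
  longMetric("spillSize"),
  pythonMetrics)
child.execute().mapPartitionsWithEvaluator(evaluatorFactory)
```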