From 95dc829c92ae7cac51662895e2c1e79ada7ff4e6 Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Tue, 18 Jul 2023 11:37:14 +0800
Subject: [PATCH] [SPARK-44341][SQL][PYTHON] Define the computing logic through PartitionEvaluator API and use it in WindowExec and WindowInPandasExec

### What changes were proposed in this pull request?
`WindowExec` and `WindowInPandasExec` are updated to use the `PartitionEvaluator` API to perform execution.

### Why are the changes needed?
To define the computing logic in one place and require the caller side to explicitly list what needs to be serialized and sent to executors. (A minimal usage sketch of the evaluator API is appended after the diff.)

### Does this PR introduce _any_ user-facing change?
'No'. This only updates the internal implementation.

### How was this patch tested?
Added new test cases.

Closes #41939 from beliefer/SPARK-44341.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan
---
 .../WindowInPandasEvaluatorFactory.scala | 369 ++++++++++++++++
 .../execution/python/WindowInPandasExec.scala | 332 +-------------
 .../window/WindowEvaluatorFactory.scala | 417 ++++++++++++++++++
 .../sql/execution/window/WindowExec.scala | 118 +----
 .../sql/execution/window/WindowExecBase.scala | 257 -----------
 5 files changed, 817 insertions(+), 676 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
new file mode 100644
index 0000000000000..364e94ab158e9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -0,0 +1,369 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{JobArtifactSet, PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.Utils + +class WindowInPandasEvaluatorFactory( + val windowExpression: Seq[NamedExpression], + val partitionSpec: Seq[Expression], + val orderSpec: Seq[SortOrder], + val childOutput: Seq[Attribute], + val spillSize: SQLMetric, + pythonMetrics: Map[String, SQLMetric]) + extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase { + + /** + * Helper functions and data structures for window bounds + * + * It contains: + * (1) Total number of window bound indices in the python input row + * (2) Function from frame index to its lower bound column index in the python input row + * (3) Function from frame index to its upper bound column index in the python input row + * (4) Seq from frame index to its window bound type + */ + private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType]) + + /** + * Enum for window bound types. Used only inside this class. + */ + private sealed case class WindowBoundType(value: String) + + private object UnboundedWindow extends WindowBoundType("unbounded") + + private object BoundedWindow extends WindowBoundType("bounded") + + private val windowBoundTypeConf = "pandas_window_bound_types" + + private def collectFunctions( + udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonFuncExpression) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression]))) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + // Helper functions + /** + * See [[WindowBoundHelpers]] for details. 
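+   * For example, given three frames (unbounded, bounded, bounded), requiredIndices is
+   * (0, 2, 2), the per-frame (lower, upper) bound column indices are (-1, -1), (0, 1) and
+   * (2, 3), and 4 bound columns are prepended to the Python input row in total.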
+ */ + private def computeWindowBoundHelpers( + factories: Seq[InternalRow => WindowFunctionFrame]): WindowBoundHelpers = { + val functionFrames = factories.map(_ (EmptyRow)) + + val windowBoundTypes = functionFrames.map { + case _: UnboundedWindowFunctionFrame => UnboundedWindow + case _: UnboundedFollowingWindowFunctionFrame | + _: SlidingWindowFunctionFrame | + _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow + // It should be impossible to get other types of window function frame here + case frame => throw QueryExecutionErrors.unexpectedWindowFunctionFrameError(frame.toString) + } + + val requiredIndices = functionFrames.map { + case _: UnboundedWindowFunctionFrame => 0 + case _ => 2 + } + + val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail + + val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) => + if (num == 0) { + // Sentinel values for unbounded window + (-1, -1) + } else { + (upperBoundIndex - 2, upperBoundIndex - 1) + } + } + + def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1 + + def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2 + + (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes) + } + + override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = { + new WindowInPandasPartitionEvaluator() + } + + class WindowInPandasPartitionEvaluator extends PartitionEvaluator[InternalRow, InternalRow] { + private val conf: SQLConf = SQLConf.get + + // Unwrap the expressions and factories from the map. + private val expressionsWithFrameIndex = + windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap { + case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex)) + } + + private val expressions = expressionsWithFrameIndex.map(_._1) + private val expressionIndexToFrameIndex = + expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap + + private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + + private val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = + computeWindowBoundHelpers(factories) + private val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } + private val numFrames = factories.length + + private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + private val spillThreshold = conf.windowExecBufferSpillThreshold + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val largeVarTypes = conf.arrowUseLargeVarTypes + + // Extract window expressions and window functions + private val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e }) + private val udfExpressions = windowExpressions.map { e => + e.windowFunction.asInstanceOf[AggregateExpression].aggregateFunction.asInstanceOf[PythonUDAF] + } + + // We shouldn't be chaining anything here. + // All chained python functions should only contain one function. + private val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + require(pyFuncs.length == expressions.length) + + private val udfWindowBoundTypes = pyFuncs.indices.map(i => + frameWindowBoundTypes(expressionIndexToFrameIndex(i))) + private val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf) + + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(","))) + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. 
This is similar to how other Python UDF nodes
+    // handle UDF inputs.
+    private val dataInputs = new ArrayBuffer[Expression]
+    private val dataInputTypes = new ArrayBuffer[DataType]
+    private val argOffsets = inputs.map { input =>
+      input.map { e =>
+        if (dataInputs.exists(_.semanticEquals(e))) {
+          dataInputs.indexWhere(_.semanticEquals(e))
+        } else {
+          dataInputs += e
+          dataInputTypes += e.dataType
+          dataInputs.length - 1
+        }
+      }.toArray
+    }.toArray
+
+    // In addition to UDF inputs, we will prepend window bounds for each UDF.
+    // For bounded windows, we prepend the lower bound and upper bound. For unbounded windows,
+    // we do not add window bounds. (Strictly speaking, we only need the lower or upper bound
+    // if the window is bounded on only one side; this can be improved in the future.)
+
+    // Set window bounds for each window frame. Each window frame has different bounds, so
+    // each has its own window bound columns.
+    private val windowBoundsInput = factories.indices.flatMap { frameIndex =>
+      if (isBounded(frameIndex)) {
+        Seq(
+          BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false),
+          BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false)
+        )
+      } else {
+        Seq.empty
+      }
+    }
+
+    // Set the window-bounds argOffsets for each UDF. For UDFs with a bounded window, the
+    // argOffsets are (lowerBoundOffset, upperBoundOffset, inputOffset1, inputOffset2, ...).
+    // For UDFs with an unbounded window, the argOffsets are (inputOffset1, inputOffset2, ...).
+    pyFuncs.indices.foreach { exprIndex =>
+      val frameIndex = expressionIndexToFrameIndex(exprIndex)
+      if (isBounded(frameIndex)) {
+        argOffsets(exprIndex) =
+          Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
+            argOffsets(exprIndex).map(_ + windowBoundsInput.length)
+      } else {
+        argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length)
+      }
+    }
+
+    private val allInputs = windowBoundsInput ++ dataInputs
+    private val allInputTypes = allInputs.map(_.dataType)
+    private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+    override def eval(
+        partitionIndex: Int,
+        inputs: Iterator[InternalRow]*): Iterator[InternalRow] = {
+      val iter = inputs.head
+      val context = TaskContext.get()
+
+      // Get all relevant projections.
+      val resultProj = createResultProjection(expressions)
+      val pythonInputProj = UnsafeProjection.create(
+        allInputs,
+        windowBoundsInput.map(ref =>
+          AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ childOutput
+      )
+      val pythonInputSchema = StructType(
+        allInputTypes.zipWithIndex.map { case (dt, i) =>
+          StructField(s"_$i", dt)
+        }
+      )
+      val grouping = UnsafeProjection.create(partitionSpec, childOutput)
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = HybridRowQueue(context.taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), childOutput.length)
+      context.addTaskCompletionListener[Unit] { _ =>
+        queue.close()
+      }
+
+      val stream = iter.map { row =>
+        queue.add(row.asInstanceOf[UnsafeRow])
+        row
+      }
+
+      val pythonInput = new Iterator[Iterator[UnsafeRow]] {
+
+        // Manage the stream and the grouping.
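+        // nextRow/nextGroup always hold the first row of the upcoming group (or null at the
+        // end of input); fetchNextRow advances `stream`, which also appends each consumed
+        // row to `queue` as a side effect.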
+ var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null + var nextRowAvailable: Boolean = false + + private[this] def fetchNextRow(): Unit = { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next().asInstanceOf[UnsafeRow] + nextGroup = grouping(nextRow) + } else { + nextRow = null + nextGroup = null + } + } + + fetchNextRow() + + // Manage the current partition. + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ + + val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) + + val frames = factories.map(_ (indexRow)) + + private[this] def fetchNextPartition(): Unit = { + // Collect all the rows in the current partition. + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() + + // clear last partition + buffer.clear() + + while (nextRowAvailable && nextGroup == currentGroup) { + buffer.add(nextRow) + fetchNextRow() + } + + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(buffer) + i += 1 + } + + // Setup iteration + rowIndex = 0 + bufferIterator = buffer.generateIterator() + } + + // Iteration + var rowIndex = 0 + + override final def hasNext: Boolean = { + val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable + if (!found) { + // clear final partition + buffer.clear() + spillSize += buffer.spillSize + } + found + } + + override final def next(): Iterator[UnsafeRow] = { + // Load the next partition if we need to. + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { + fetchNextPartition() + } + + val join = new JoinedRow + + bufferIterator.zipWithIndex.map { + case (current, index) => + var frameIndex = 0 + while (frameIndex < numFrames) { + frames(frameIndex).write(index, current) + // If the window is unbounded we don't need to write out window bounds. 
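+              // For bounded frames, write the frame's current lower/upper row indices into
+              // indexRow so the Python side can slice the partition for this UDF.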
+ if (isBounded(frameIndex)) { + indexRow.setInt( + lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound()) + indexRow.setInt( + upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound()) + } + frameIndex += 1 + } + + pythonInputProj(join(indexRow, current)) + } + } + } + + val windowFunctionResult = new ArrowPythonRunner( + pyFuncs, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + argOffsets, + pythonInputSchema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID).compute(pythonInput, context.partitionId(), context) + + val joined = new JoinedRow + + windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, windowOutput) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 3d43c417dcb23..ba1f2c132ff5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -17,24 +17,12 @@ package org.apache.spark.sql.execution.python -import java.io.File - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.window._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.util.Utils /** * This class calculates and outputs windowed aggregates over the rows in a single partition. @@ -91,313 +79,25 @@ case class WindowInPandasExec( "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size") ) - /** - * Helper functions and data structures for window bounds - * - * It contains: - * (1) Total number of window bound indices in the python input row - * (2) Function from frame index to its lower bound column index in the python input row - * (3) Function from frame index to its upper bound column index in the python input row - * (4) Seq from frame index to its window bound type - */ - private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType]) - - /** - * Enum for window bound types. Used only inside this class. 
- */ - private sealed case class WindowBoundType(value: String) - private object UnboundedWindow extends WindowBoundType("unbounded") - private object BoundedWindow extends WindowBoundType("bounded") - - private val windowBoundTypeConf = "pandas_window_bound_types" - - private def collectFunctions( - udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = { - udf.children match { - case Seq(u: PythonFuncExpression) => - val (chained, children) = collectFunctions(u) - (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) - case children => - // There should not be any other UDFs, or the children can't be evaluated directly. - assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression]))) - (ChainedPythonFunctions(Seq(udf.func)), udf.children) - } - } - - /** - * See [[WindowBoundHelpers]] for details. - */ - private def computeWindowBoundHelpers( - factories: Seq[InternalRow => WindowFunctionFrame] - ): WindowBoundHelpers = { - val functionFrames = factories.map(_(EmptyRow)) - - val windowBoundTypes = functionFrames.map { - case _: UnboundedWindowFunctionFrame => UnboundedWindow - case _: UnboundedFollowingWindowFunctionFrame | - _: SlidingWindowFunctionFrame | - _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow - // It should be impossible to get other types of window function frame here - case frame => throw QueryExecutionErrors.unexpectedWindowFunctionFrameError(frame.toString) - } - - val requiredIndices = functionFrames.map { - case _: UnboundedWindowFunctionFrame => 0 - case _ => 2 - } - - val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail - - val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) => - if (num == 0) { - // Sentinel values for unbounded window - (-1, -1) - } else { - (upperBoundIndex - 2, upperBoundIndex - 1) - } - } - - def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1 - def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2 - - (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes) - } - protected override def doExecute(): RDD[InternalRow] = { - // Unwrap the expressions and factories from the map. - val expressionsWithFrameIndex = - windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap { - case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex)) - } - - val expressions = expressionsWithFrameIndex.map(_._1) - val expressionIndexToFrameIndex = - expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap - - val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray - - // Helper functions - val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = - computeWindowBoundHelpers(factories) - val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } - val numFrames = factories.length - - val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold - val spillThreshold = conf.windowExecBufferSpillThreshold - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val largeVarTypes = conf.arrowUseLargeVarTypes - - // Extract window expressions and window functions - val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e }) - val udfExpressions = windowExpressions.map { e => - e.windowFunction.asInstanceOf[AggregateExpression].aggregateFunction.asInstanceOf[PythonUDAF] - } - - // We shouldn't be chaining anything here. - // All chained python functions should only contain one function. 
- val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip - require(pyFuncs.length == expressions.length) - - val udfWindowBoundTypes = pyFuncs.indices.map(i => - frameWindowBoundTypes(expressionIndexToFrameIndex(i))) - val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf) - + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(","))) - - // Filter child output attributes down to only those that are UDF inputs. - // Also eliminate duplicate UDF inputs. This is similar to how other Python UDF node - // handles UDF inputs. - val dataInputs = new ArrayBuffer[Expression] - val dataInputTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => - input.map { e => - if (dataInputs.exists(_.semanticEquals(e))) { - dataInputs.indexWhere(_.semanticEquals(e)) - } else { - dataInputs += e - dataInputTypes += e.dataType - dataInputs.length - 1 - } - }.toArray - }.toArray - - // In addition to UDF inputs, we will prepend window bounds for each UDFs. - // For bounded windows, we prepend lower bound and upper bound. For unbounded windows, - // we no not add window bounds. (strictly speaking, we only need to lower or upper bound - // if the window is bounded only on one side, this can be improved in the future) - - // Setting window bounds for each window frames. Each window frame has different bounds so - // each has its own window bound columns. - val windowBoundsInput = factories.indices.flatMap { frameIndex => - if (isBounded(frameIndex)) { - Seq( - BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false), - BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false) - ) - } else { - Seq.empty - } - } - - // Setting the window bounds argOffset for each UDF. For UDFs with bounded window, argOffset - // for the UDF is (lowerBoundOffset, upperBoundOffset, inputOffset1, inputOffset2, ...) - // For UDFs with unbounded window, argOffset is (inputOffset1, inputOffset2, ...) - pyFuncs.indices.foreach { exprIndex => - val frameIndex = expressionIndexToFrameIndex(exprIndex) - if (isBounded(frameIndex)) { - argOffsets(exprIndex) = - Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++ - argOffsets(exprIndex).map(_ + windowBoundsInput.length) - } else { - argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length) - } - } - - val allInputs = windowBoundsInput ++ dataInputs - val allInputTypes = allInputs.map(_.dataType) val spillSize = longMetric("spillSize") - val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) - - // Start processing. - child.execute().mapPartitions { iter => - val context = TaskContext.get() - - // Get all relevant projections. - val resultProj = createResultProjection(expressions) - val pythonInputProj = UnsafeProjection.create( - allInputs, - windowBoundsInput.map(ref => - AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output - ) - val pythonInputSchema = StructType( - allInputTypes.zipWithIndex.map { case (dt, i) => - StructField(s"_$i", dt) - } - ) - val grouping = UnsafeProjection.create(partitionSpec, child.output) - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. 
- val queue = HybridRowQueue(context.taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener[Unit] { _ => - queue.close() - } - - val stream = iter.map { row => - queue.add(row.asInstanceOf[UnsafeRow]) - row - } - - val pythonInput = new Iterator[Iterator[UnsafeRow]] { - - // Manage the stream and the grouping. - var nextRow: UnsafeRow = null - var nextGroup: UnsafeRow = null - var nextRowAvailable: Boolean = false - private[this] def fetchNextRow(): Unit = { - nextRowAvailable = stream.hasNext - if (nextRowAvailable) { - nextRow = stream.next().asInstanceOf[UnsafeRow] - nextGroup = grouping(nextRow) - } else { - nextRow = null - nextGroup = null - } - } - fetchNextRow() - - // Manage the current partition. - val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) - var bufferIterator: Iterator[UnsafeRow] = _ - - val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) - val frames = factories.map(_(indexRow)) + val evaluatorFactory = + new WindowInPandasEvaluatorFactory( + windowExpression, + partitionSpec, + orderSpec, + child.output, + spillSize, + pythonMetrics) - private[this] def fetchNextPartition(): Unit = { - // Collect all the rows in the current partition. - // Before we start to fetch new input rows, make a copy of nextGroup. - val currentGroup = nextGroup.copy() - - // clear last partition - buffer.clear() - - while (nextRowAvailable && nextGroup == currentGroup) { - buffer.add(nextRow) - fetchNextRow() - } - - // Setup the frames. - var i = 0 - while (i < numFrames) { - frames(i).prepare(buffer) - i += 1 - } - - // Setup iteration - rowIndex = 0 - bufferIterator = buffer.generateIterator() - } - - // Iteration - var rowIndex = 0 - - override final def hasNext: Boolean = { - val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable - if (!found) { - // clear final partition - buffer.clear() - spillSize += buffer.spillSize - } - found - } - - override final def next(): Iterator[UnsafeRow] = { - // Load the next partition if we need to. - if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { - fetchNextPartition() - } - - val join = new JoinedRow - - bufferIterator.zipWithIndex.map { - case (current, index) => - var frameIndex = 0 - while (frameIndex < numFrames) { - frames(frameIndex).write(index, current) - // If the window is unbounded we don't need to write out window bounds. - if (isBounded(frameIndex)) { - indexRow.setInt( - lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound()) - indexRow.setInt( - upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound()) - } - frameIndex += 1 - } - - pythonInputProj(join(indexRow, current)) - } - } - } - - val windowFunctionResult = new ArrowPythonRunner( - pyFuncs, - PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, - argOffsets, - pythonInputSchema, - sessionLocalTimeZone, - largeVarTypes, - pythonRunnerConf, - pythonMetrics, - jobArtifactUUID).compute(pythonInput, context.partitionId(), context) - - val joined = new JoinedRow - - windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => - val leftRow = queue.remove() - val joinedRow = joined(leftRow, windowOutput) - resultProj(joinedRow) + // Start processing. 
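+    // Either hand the factory to the RDD (so the evaluator is created on executors) or
+    // fall back to plain mapPartitions; the partition index passed to eval() is unused
+    // by this evaluator, which reads the TaskContext instead.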
+ if (conf.usePartitionEvaluator) { + child.execute().mapPartitionsWithEvaluator(evaluatorFactory) + } else { + child.execute().mapPartitions { iter => + val evaluator = evaluatorFactory.createEvaluator() + evaluator.eval(0, iter) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala new file mode 100644 index 0000000000000..913f8762c7953 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, JoinedRow, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecificInternalRow, SpecifiedWindowFrame, TimeAdd, TimestampAddYMInterval, UnaryMinus, UnboundedFollowing, UnboundedPreceding, UnsafeProjection, UnsafeRow, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{CalendarIntervalType, DateType, DayTimeIntervalType, DecimalType, IntegerType, TimestampNTZType, TimestampType, YearMonthIntervalType} +import org.apache.spark.util.collection.Utils + +trait WindowEvaluatorFactoryBase { + def windowExpression: Seq[NamedExpression] + def partitionSpec: Seq[Expression] + def orderSpec: Seq[SortOrder] + def childOutput: Seq[Attribute] + def spillSize: SQLMetric + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. 
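+   * @note Window results are appended to the right of the child output, so result column
+   *       `i` is bound at ordinal `childOutput.size + i`.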
+ */ + protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(childOutput.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = Utils.toMap(expressions, references) + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + childOutput ++ patchedWindowExpression, + childOutput) + } + + /** + * Create a bound ordering object for a given frame type and offset. A bound ordering object is + * used to determine which input row lies within the frame boundaries of an output row. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frame to evaluate. This can either be a Row or Range frame. + * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. + * @return a bound ordering object. + */ + private def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { + (frame, bound) match { + case (RowFrame, CurrentRow) => + RowBoundOrdering(0) + + case (RowFrame, IntegerLiteral(offset)) => + RowBoundOrdering(offset) + + case (RowFrame, _) => + throw new IllegalStateException(s"Unhandled bound in windows expressions: $bound") + + case (RangeFrame, CurrentRow) => + val ordering = RowOrdering.create(orderSpec, childOutput) + RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) + + case (RangeFrame, offset: Expression) if orderSpec.size == 1 => + // Use only the first order expression when the offset is non-null. + val sortExpr = orderSpec.head + val expr = sortExpr.child + + // Create the projection which returns the current 'value'. + val current = MutableProjection.create(expr :: Nil, childOutput) + + // Flip the sign of the offset when processing the order is descending + val boundOffset = sortExpr.direction match { + case Descending => UnaryMinus(offset) + case Ascending => offset + } + + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = (expr.dataType, boundOffset.dataType) match { + case (DateType, IntegerType) => DateAdd(expr, boundOffset) + case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(expr, boundOffset) + case (TimestampType | TimestampNTZType, CalendarIntervalType) => + TimeAdd(expr, boundOffset, Some(timeZone)) + case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => + TimestampAddYMInterval(expr, boundOffset, Some(timeZone)) + case (TimestampType | TimestampNTZType, _: DayTimeIntervalType) => + TimeAdd(expr, boundOffset, Some(timeZone)) + case (d: DecimalType, _: DecimalType) => DecimalAddNoOverflowCheck(expr, boundOffset, d) + case (a, b) if a == b => Add(expr, boundOffset) + } + val bound = MutableProjection.create(boundExpr :: Nil, childOutput) + + // Construct the ordering. This is used to compare the result of current value projection + // to the result of bound value projection. This is done manually because we want to use + // Code Generation (if it is enabled). 
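+        // The comparison happens on the single-column rows produced by the two projections,
+        // so the sort expression is rebound to ordinal 0 below.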
+        val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
+        val ordering = RowOrdering.create(boundSortExprs, Nil)
+        RangeBoundOrdering(ordering, current, bound)
+
+      case (RangeFrame, _) =>
+        throw new IllegalStateException("Non-Zero range offsets are not supported for windows " +
+          "with multiple order expressions.")
+    }
+  }
+
+  /**
+   * Collection containing an entry for each window frame to process. Each entry contains a frame's
+   * [[WindowExpression]]s and factory function for the [[WindowFunctionFrame]].
+   */
+  protected lazy val windowFrameExpressionFactoryPairs = {
+    type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression])
+    type ExpressionBuffer = mutable.Buffer[Expression]
+    val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
+
+    // Add an expression and its corresponding window function to the map for a given frame.
+    def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
+      val key = fn match {
+        // This branch is used by Lead/Lag to support ignoring nulls, and to optimize the
+        // performance of NthValue when it ignores nulls.
+        // All window frames move in rows. If there are multiple Leads, Lags or NthValues acting on
+        // a row and operating on different input expressions, they should not be moved uniformly
+        // by row. Therefore, we put these functions in different window frames.
+        case f: OffsetWindowFunction if f.ignoreNulls =>
+          (tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized))
+        case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil)
+      }
+      val (es, fns) = framedFunctions.getOrElseUpdate(
+        key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
+      es += e
+      fns += fn
+    }
+
+    // Collect all valid window functions and group them by their frame.
+    windowExpression.foreach { x =>
+      x.foreach {
+        case e @ WindowExpression(function, spec) =>
+          val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+          function match {
+            case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", frame, e, f)
+            case f: FrameLessOffsetWindowFunction =>
+              collect("FRAME_LESS_OFFSET", f.fakeFrame, e, f)
+            case f: OffsetWindowFunction if frame.frameType == RowFrame &&
+                frame.lower == UnboundedPreceding =>
+              frame.upper match {
+                case UnboundedFollowing => collect("UNBOUNDED_OFFSET", f.fakeFrame, e, f)
+                case CurrentRow => collect("UNBOUNDED_PRECEDING_OFFSET", f.fakeFrame, e, f)
+                case _ => collect("AGGREGATE", frame, e, f)
+              }
+            case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
+            case f => throw new IllegalStateException(s"Unsupported window function: $f")
+          }
+        case _ =>
+      }
+    }
+
+    // Map each group to an (unbound expressions, frame factory) pair.
+    var numExpressions = 0
+    val timeZone = SQLConf.get.sessionLocalTimeZone
+    framedFunctions.toSeq.map {
+      case (key, (expressions, functionSeq)) =>
+        val ordinal = numExpressions
+        val functions = functionSeq.toArray
+
+        // Construct an aggregate processor if we need one.
+        // Currently we don't allow mixing of Pandas UDFs and SQL aggregation functions
+        // in a single Window physical node. Therefore, we can assume there are no SQL
+        // aggregation functions if a Pandas UDF exists. In the future, we might mix Pandas
+        // UDFs and SQL aggregation functions in a single physical node.
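+        // A Python UDAF therefore yields a null processor; the pandas path only uses the
+        // frames built below to compute window bounds, and the frame implementations are
+        // expected to guard against a null processor.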
+        def processor = if (functions.exists(_.isInstanceOf[PythonFuncExpression])) {
+          null
+        } else {
+          AggregateProcessor(
+            functions,
+            ordinal,
+            childOutput,
+            (expressions, schema) =>
+              MutableProjection.create(expressions, schema))
+        }
+
+        // Create the factory to produce WindowFunctionFrame.
+        val factory = key match {
+          // Frameless offset Frame
+          case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) =>
+            target: InternalRow =>
+              new FrameLessOffsetWindowFunctionFrame(
+                target,
+                ordinal,
+                // OFFSET frame functions are guaranteed to be OffsetWindowFunctions.
+                functions.map(_.asInstanceOf[OffsetWindowFunction]),
+                childOutput,
+                (expressions, schema) =>
+                  MutableProjection.create(expressions, schema),
+                offset,
+                expr.nonEmpty)
+          case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, expr) =>
+            target: InternalRow => {
+              new UnboundedOffsetWindowFunctionFrame(
+                target,
+                ordinal,
+                // OFFSET frame functions are guaranteed to be OffsetWindowFunctions.
+                functions.map(_.asInstanceOf[OffsetWindowFunction]),
+                childOutput,
+                (expressions, schema) =>
+                  MutableProjection.create(expressions, schema),
+                offset,
+                expr.nonEmpty)
+            }
+          case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, expr) =>
+            target: InternalRow => {
+              new UnboundedPrecedingOffsetWindowFunctionFrame(
+                target,
+                ordinal,
+                // OFFSET frame functions are guaranteed to be OffsetWindowFunctions.
+                functions.map(_.asInstanceOf[OffsetWindowFunction]),
+                childOutput,
+                (expressions, schema) =>
+                  MutableProjection.create(expressions, schema),
+                offset,
+                expr.nonEmpty)
+            }
+
+          // Entire Partition Frame.
+          case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) =>
+            target: InternalRow => {
+              new UnboundedWindowFunctionFrame(target, processor)
+            }
+
+          // Growing Frame.
+          case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) =>
+            target: InternalRow => {
+              new UnboundedPrecedingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, upper, timeZone))
+            }
+
+          // Shrinking Frame.
+          case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) =>
+            target: InternalRow => {
+              new UnboundedFollowingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, lower, timeZone))
+            }
+
+          // Moving Frame.
+          case ("AGGREGATE", frameType, lower, upper, _) =>
+            target: InternalRow => {
+              new SlidingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, lower, timeZone),
+                createBoundOrdering(frameType, upper, timeZone))
+            }
+
+          case _ =>
+            throw new IllegalStateException(s"Unsupported factory: $key")
+        }
+
+        // Keep track of the number of expressions. This is a side-effect in a map...
+        numExpressions += expressions.size
+
+        // Create the Window Expression - Frame Factory pair.
+        (expressions, factory)
+    }
+  }
+
+}
+
+class WindowEvaluatorFactory(
+    val windowExpression: Seq[NamedExpression],
+    val partitionSpec: Seq[Expression],
+    val orderSpec: Seq[SortOrder],
+    val childOutput: Seq[Attribute],
+    val spillSize: SQLMetric)
+  extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase {
+
+  override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = {
+    new WindowPartitionEvaluator()
+  }
+
+  class WindowPartitionEvaluator extends PartitionEvaluator[InternalRow, InternalRow] {
+    private val conf: SQLConf = SQLConf.get
+
+    // Unwrap the window expressions and window frame factories from the map.
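+    // Each pair couples the window expressions that share a frame with the factory that
+    // builds that frame's WindowFunctionFrame for a given target row.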
+ private val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + private val spillThreshold = conf.windowExecBufferSpillThreshold + + override def eval( + partitionIndex: Int, + inputs: Iterator[InternalRow]*): Iterator[InternalRow] = { + val stream = inputs.head + new Iterator[InternalRow] { + + // Get all relevant projections. + val result = createResultProjection(expressions) + val grouping = UnsafeProjection.create(partitionSpec, childOutput) + + // Manage the stream and the grouping. + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow(): Unit = { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next().asInstanceOf[UnsafeRow] + nextGroup = grouping(nextRow) + } else { + nextRow = null + nextGroup = null + } + } + fetchNextRow() + + // Manage the current partition. + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + + var bufferIterator: Iterator[UnsafeRow] = _ + + val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) + val numFrames = frames.length + private[this] def fetchNextPartition(): Unit = { + // Collect all the rows in the current partition. + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() + + // clear last partition + buffer.clear() + + while (nextRowAvailable && nextGroup == currentGroup) { + buffer.add(nextRow) + fetchNextRow() + } + + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(buffer) + i += 1 + } + + // Setup iteration + rowIndex = 0 + bufferIterator = buffer.generateIterator() + } + + // Iteration + var rowIndex = 0 + + override final def hasNext: Boolean = { + val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable + if (!found) { + // clear final partition + buffer.clear() + spillSize += buffer.spillSize + } + found + } + + val join = new JoinedRow + override final def next(): InternalRow = { + // Load the next partition if we need to. + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { + fetchNextPartition() + } + + if (bufferIterator.hasNext) { + val current = bufferIterator.next() + + // Get the results for the window frames. + var i = 0 + while (i < numFrames) { + frames(i).write(rowIndex, current) + i += 1 + } + + // 'Merge' the input row with the window function result + join(current, windowFunctionResult) + rowIndex += 1 + + // Return the projection. 
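+            // (The projection reuses its result row across calls, per the usual SparkPlan
+            // iterator contract, so consumers must copy rows they need to buffer.)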
+ result(join) + } else { + throw new NoSuchElementException + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index dda5da6c9e9f4..35e59aef94fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.window import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} /** @@ -95,111 +95,23 @@ case class WindowExec( ) protected override def doExecute(): RDD[InternalRow] = { - // Unwrap the window expressions and window frame factories from the map. - val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) - val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray - val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold - val spillThreshold = conf.windowExecBufferSpillThreshold val spillSize = longMetric("spillSize") - // Start processing. - child.execute().mapPartitions { stream => - new Iterator[InternalRow] { - - // Get all relevant projections. - val result = createResultProjection(expressions) - val grouping = UnsafeProjection.create(partitionSpec, child.output) - - // Manage the stream and the grouping. - var nextRow: UnsafeRow = null - var nextGroup: UnsafeRow = null - var nextRowAvailable: Boolean = false - private[this] def fetchNextRow(): Unit = { - nextRowAvailable = stream.hasNext - if (nextRowAvailable) { - nextRow = stream.next().asInstanceOf[UnsafeRow] - nextGroup = grouping(nextRow) - } else { - nextRow = null - nextGroup = null - } - } - fetchNextRow() - - // Manage the current partition. - val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) - - var bufferIterator: Iterator[UnsafeRow] = _ - - val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) - val frames = factories.map(_(windowFunctionResult)) - val numFrames = frames.length - private[this] def fetchNextPartition(): Unit = { - // Collect all the rows in the current partition. - // Before we start to fetch new input rows, make a copy of nextGroup. - val currentGroup = nextGroup.copy() - - // clear last partition - buffer.clear() - - while (nextRowAvailable && nextGroup == currentGroup) { - buffer.add(nextRow) - fetchNextRow() - } - - // Setup the frames. - var i = 0 - while (i < numFrames) { - frames(i).prepare(buffer) - i += 1 - } + val evaluatorFactory = + new WindowEvaluatorFactory( + windowExpression, + partitionSpec, + orderSpec, + child.output, + spillSize) - // Setup iteration - rowIndex = 0 - bufferIterator = buffer.generateIterator() - } - - // Iteration - var rowIndex = 0 - - override final def hasNext: Boolean = { - val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable - if (!found) { - // clear final partition - buffer.clear() - spillSize += buffer.spillSize - } - found - } - - val join = new JoinedRow - override final def next(): InternalRow = { - // Load the next partition if we need to. 
- if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { - fetchNextPartition() - } - - if (bufferIterator.hasNext) { - val current = bufferIterator.next() - - // Get the results for the window frames. - var i = 0 - while (i < numFrames) { - frames(i).write(rowIndex, current) - i += 1 - } - - // 'Merge' the input row with the window function result - join(current, windowFunctionResult) - rowIndex += 1 - - // Return the projection. - result(join) - } else { - throw new NoSuchElementException - } - } + // Start processing. + if (conf.usePartitionEvaluator) { + child.execute().mapPartitionsWithEvaluator(evaluatorFactory) + } else { + child.execute().mapPartitions { iter => + val evaluator = evaluatorFactory.createEvaluator() + evaluator.eval(0, iter) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala index 82fc308e4095b..29f2256efc178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -17,16 +17,9 @@ package org.apache.spark.sql.execution.window -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.UnaryExecNode -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.Utils /** * Holds common logic for window operators @@ -57,254 +50,4 @@ trait WindowExecBase extends UnaryExecNode { override def outputPartitioning: Partitioning = child.outputPartitioning - /** - * Create the resulting projection. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. - */ - protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map { case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) - } - val unboundToRefMap = Utils.toMap(expressions, references) - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) - } - - /** - * Create a bound ordering object for a given frame type and offset. A bound ordering object is - * used to determine which input row lies within the frame boundaries of an output row. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame to evaluate. This can either be a Row or Range frame. - * @param bound with respect to the row. - * @param timeZone the session local timezone for time related calculations. - * @return a bound ordering object. 
- */ - private def createBoundOrdering( - frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { - (frame, bound) match { - case (RowFrame, CurrentRow) => - RowBoundOrdering(0) - - case (RowFrame, IntegerLiteral(offset)) => - RowBoundOrdering(offset) - - case (RowFrame, _) => - throw new IllegalStateException(s"Unhandled bound in windows expressions: $bound") - - case (RangeFrame, CurrentRow) => - val ordering = RowOrdering.create(orderSpec, child.output) - RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) - - case (RangeFrame, offset: Expression) if orderSpec.size == 1 => - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - - // Create the projection which returns the current 'value'. - val current = MutableProjection.create(expr :: Nil, child.output) - - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => UnaryMinus(offset) - case Ascending => offset - } - - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = (expr.dataType, boundOffset.dataType) match { - case (DateType, IntegerType) => DateAdd(expr, boundOffset) - case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(expr, boundOffset) - case (TimestampType | TimestampNTZType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) - case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => - TimestampAddYMInterval(expr, boundOffset, Some(timeZone)) - case (TimestampType | TimestampNTZType, _: DayTimeIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) - case (d: DecimalType, _: DecimalType) => DecimalAddNoOverflowCheck(expr, boundOffset, d) - case (a, b) if a == b => Add(expr, boundOffset) - } - val bound = MutableProjection.create(boundExpr :: Nil, child.output) - - // Construct the ordering. This is used to compare the result of current value projection - // to the result of bound value projection. This is done manually because we want to use - // Code Generation (if it is enabled). - val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil - val ordering = RowOrdering.create(boundSortExprs, Nil) - RangeBoundOrdering(ordering, current, bound) - - case (RangeFrame, _) => - throw new IllegalStateException("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") - } - } - - /** - * Collection containing an entry for each window frame to process. Each entry contains a frame's - * [[WindowExpression]]s and factory function for the [[WindowFunctionFrame]]. - */ - protected lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression]) - type ExpressionBuffer = mutable.Buffer[Expression] - val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] - - // Add a function and its function to the map for a given frame. - def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = fn match { - // This branch is used for Lead/Lag to support ignoring null and optimize the performance - // for NthValue ignoring null. - // All window frames move in rows. If there are multiple Leads, Lags or NthValues acting on - // a row and operating on different input expressions, they should not be moved uniformly - // by row. 
Therefore, we put these functions in different window frames. - case f: OffsetWindowFunction if f.ignoreNulls => - (tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized)) - case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil) - } - val (es, fns) = framedFunctions.getOrElseUpdate( - key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) - es += e - fns += fn - } - - // Collect all valid window functions and group them by their frame. - windowExpression.foreach { x => - x.foreach { - case e @ WindowExpression(function, spec) => - val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - function match { - case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", frame, e, f) - case f: FrameLessOffsetWindowFunction => - collect("FRAME_LESS_OFFSET", f.fakeFrame, e, f) - case f: OffsetWindowFunction if frame.frameType == RowFrame && - frame.lower == UnboundedPreceding => - frame.upper match { - case UnboundedFollowing => collect("UNBOUNDED_OFFSET", f.fakeFrame, e, f) - case CurrentRow => collect("UNBOUNDED_PRECEDING_OFFSET", f.fakeFrame, e, f) - case _ => collect("AGGREGATE", frame, e, f) - } - case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) - case f => throw new IllegalStateException(s"Unsupported window function: $f") - } - case _ => - } - } - - // Map the groups to a (unbound) expression and frame factory pair. - var numExpressions = 0 - val timeZone = conf.sessionLocalTimeZone - framedFunctions.toSeq.map { - case (key, (expressions, functionSeq)) => - val ordinal = numExpressions - val functions = functionSeq.toArray - - // Construct an aggregate processor if we need one. - // Currently we don't allow mixing of Pandas UDF and SQL aggregation functions - // in a single Window physical node. Therefore, we can assume no SQL aggregation - // functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL - // aggregation function in a single physical node. - def processor = if (functions.exists(_.isInstanceOf[PythonFuncExpression])) { - null - } else { - AggregateProcessor( - functions, - ordinal, - child.output, - (expressions, schema) => - MutableProjection.create(expressions, schema)) - } - - // Create the factory to produce WindowFunctionFrame. - val factory = key match { - // Frameless offset Frame - case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => - new FrameLessOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => { - new UnboundedOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - } - case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => { - new UnboundedPrecedingOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. 
- functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - } - - // Entire Partition Frame. - case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) - } - - // Growing Frame. - case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) => - target: InternalRow => { - new UnboundedPrecedingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, upper, timeZone)) - } - - // Shrinking Frame. - case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) => - target: InternalRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone)) - } - - // Moving Frame. - case ("AGGREGATE", frameType, lower, upper, _) => - target: InternalRow => { - new SlidingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone), - createBoundOrdering(frameType, upper, timeZone)) - } - - case _ => - throw new IllegalStateException(s"Unsupported factory: $key") - } - - // Keep track of the number of expressions. This is a side-effect in a map... - numExpressions += expressions.size - - // Create the Window Expression - Frame Factory pair. - (expressions, factory) - } - } }
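
---
Reviewer note: below is a minimal sketch of the PartitionEvaluator pattern this patch adopts, using the API signatures exercised above (PartitionEvaluatorFactory.createEvaluator, PartitionEvaluator.eval, RDD.mapPartitionsWithEvaluator). The factory name and the toy logic are hypothetical, for illustration only; the real factories above carry the window expressions and metrics as constructor arguments so the state serialized to executors is explicit.

    import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory}
    import org.apache.spark.rdd.RDD

    // Hypothetical factory: everything the executor-side evaluator needs is a
    // constructor argument, which makes the serialization requirements explicit.
    class UpperCaseEvaluatorFactory extends PartitionEvaluatorFactory[String, String] {
      override def createEvaluator(): PartitionEvaluator[String, String] =
        new PartitionEvaluator[String, String] {
          // `inputs` carries one iterator per input RDD; single-input operators read inputs.head.
          override def eval(partitionIndex: Int, inputs: Iterator[String]*): Iterator[String] =
            inputs.head.map(_.toUpperCase)
        }
    }

    def run(rdd: RDD[String], usePartitionEvaluator: Boolean): RDD[String] = {
      val factory = new UpperCaseEvaluatorFactory
      if (usePartitionEvaluator) {
        // The framework creates the evaluator on the executor.
        rdd.mapPartitionsWithEvaluator(factory)
      } else {
        // Fallback mirrors doExecute() above; the index argument is unused here.
        rdd.mapPartitions { iter => factory.createEvaluator().eval(0, iter) }
      }
    }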